From 618a69e466d6324f2e02bb5a234630b5ce6702f5 Mon Sep 17 00:00:00 2001 From: aikanor Date: Mon, 1 Feb 2021 21:29:52 -0800 Subject: [PATCH 001/244] add rough cut models/data --- examples/models/CNN_genome.py | 209 +++++++++++++++++++++++++++ wilds/datasets/encodetfbs_dataset.py | 147 +++++++++++++++++++ 2 files changed, 356 insertions(+) create mode 100644 examples/models/CNN_genome.py create mode 100644 wilds/datasets/encodetfbs_dataset.py diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py new file mode 100644 index 00000000..f0115322 --- /dev/null +++ b/examples/models/CNN_genome.py @@ -0,0 +1,209 @@ +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Beagle(nn.Module): + """ + Neural net models over genomic sequence. + Input: + - sequence_length: int (default 1000) + - Shape: (N, 5, sequence_length, 1) with batch size N. + + Output: + - prediction (Tensor): float torch tensor of shape (N, ) + + TODO: Finish docstring. + """ + def __init__(self, args): + """ + Parameters + ---------- + sequence_length : int + n_genomic_features : int + """ + super(Beagle, self).__init__() + + self.dropout = args.dropout + self.num_cell_types = 1 + self.conv1 = nn.Conv2d(5, 300, (19, 1), stride = (1, 1), padding=(9,0)) + self.conv2 = nn.Conv2d(300, 200, (11, 1), stride = (1, 1), padding = (5,0)) + self.conv3 = nn.Conv2d(200, 200, (7, 1), stride = (1, 1), padding = (4,0)) + self.bn1 = nn.BatchNorm2d(300) + self.bn2 = nn.BatchNorm2d(200) + self.bn3 = nn.BatchNorm2d(200) + self.maxpool1 = nn.MaxPool2d((3, 1)) + self.maxpool2 = nn.MaxPool2d((4, 1)) + self.maxpool3 = nn.MaxPool2d((4, 1)) + + self.fc1 = nn.Linear(4200, 1000) + self.bn4 = nn.BatchNorm1d(1000) + + self.fc2 = nn.Linear(1000, 1000) + self.bn5 = nn.BatchNorm1d(1000) + + self.fc3 = nn.Linear(1000, self.num_cell_types) + + def forward(self, s): + s = s.permute(0, 2, 1).contiguous() # batch_size x 5 x 1000 + s = s.view(-1, 5, 1000, 1) # batch_size x 5 x 1000 x 1 [5 channels] + s = self.maxpool1(F.relu(self.bn1(self.conv1(s)))) # batch_size x 300 x 333 x 1 + s = self.maxpool2(F.relu(self.bn2(self.conv2(s)))) # batch_size x 200 x 83 x 1 + s = self.maxpool3(F.relu(self.bn3(self.conv3(s)))) # batch_size x 200 x 21 x 1 + s = s.view(-1, 4200) + conv_out = s + + s = F.dropout(F.relu(self.bn4(self.fc1(s))), p=self.dropout, training=self.training) # batch_size x 1000 + s = F.dropout(F.relu(self.bn5(self.fc2(s))), p=self.dropout, training=self.training) # batch_size x 1000 + + s = self.fc3(s) + + return s, conv_out + + + +#class MLP(nn.Module): +# """Just an MLP""" +# def __init__(self, n_inputs, n_outputs, width, depth, drop_out): +# super(MLP, self).__init__() +# +# self.input = nn.Linear(n_inputs, width) +# self.dropout = nn.Dropout(dropout) +# self.hiddens = nn.ModuleList([ +# nn.Linear(width,width) +# for _ in range(depth-2)]) +# self.output = nn.Linear(width, n_outputs) +# self.n_outputs = n_outputs +# +# def forward(self, x): +# x = self.input(x) +# x = self.dropout(x) +# x = F.relu(x) +# for hidden in self.hiddens: +# x = hidden(x) +# x = self.dropout(x) +# x = F.relu(x) +# x = self.output(x) +# return x + + +""" +DeepSEA architecture (Zhou & Troyanskaya, 2015). +Based on https://github.com/FunctionLab/selene/blob/master/models/deepsea.py +""" + +class DeepSEA(nn.Module): + def __init__(self, sequence_length, n_genomic_features): + """ + Parameters + ---------- + sequence_length : int + n_genomic_features : int + """ + super(DeepSEA, self).__init__() + conv_kernel_size = 8 + pool_kernel_size = 4 + + self.conv_net = nn.Sequential( + nn.Conv1d(4, 320, kernel_size=conv_kernel_size), + nn.ReLU(inplace=True), + nn.MaxPool1d( + kernel_size=pool_kernel_size, stride=pool_kernel_size), + nn.Dropout(p=0.2), + + nn.Conv1d(320, 480, kernel_size=conv_kernel_size), + nn.ReLU(inplace=True), + nn.MaxPool1d( + kernel_size=pool_kernel_size, stride=pool_kernel_size), + nn.Dropout(p=0.2), + + nn.Conv1d(480, 960, kernel_size=conv_kernel_size), + nn.ReLU(inplace=True), + nn.Dropout(p=0.5)) + + reduce_by = conv_kernel_size - 1 + pool_kernel_size = float(pool_kernel_size) + self.n_channels = int( + np.floor( + (np.floor( + (sequence_length - reduce_by) / pool_kernel_size) + - reduce_by) / pool_kernel_size) + - reduce_by) + self.classifier = nn.Sequential( + nn.Linear(960 * self.n_channels, n_genomic_features), + nn.ReLU(inplace=True), + nn.Linear(n_genomic_features, n_genomic_features), + nn.Sigmoid()) + + def forward(self, x): + """Forward propagation of a batch. + """ + out = self.conv_net(x) + reshape_out = out.view(out.size(0), 960 * self.n_channels) + predict = self.classifier(reshape_out) + return predict + +""" +def criterion(): + return nn.BCELoss() + +def get_optimizer(lr): + # The optimizer and the parameters with which to initialize the optimizer. At a later time, we initialize the optimizer by also passing in the model parameters (`model.parameters()`). We cannot initialize the optimizer until the model has been initialized. + return (torch.optim.SGD, {"lr": lr, "weight_decay": 1e-6, "momentum": 0.9}) +""" + + + +""" +DanQ architecture (Quang & Xie, 2016). +""" + +class DanQ(nn.Module): + def __init__(self, sequence_length, n_genomic_features): + """ + Parameters + ---------- + sequence_length : int + Input sequence length + n_genomic_features : int + Total number of features to predict + """ + super(DanQ, self).__init__() + self.nnet = nn.Sequential( + nn.Conv1d(4, 320, kernel_size=26), + nn.ReLU(inplace=True), + nn.MaxPool1d( + kernel_size=13, stride=13), + nn.Dropout(0.2)) + + self.bdlstm = nn.Sequential(nn.LSTM(320, 320, num_layers=1, batch_first=True, bidirectional=True)) + + self._n_channels = math.floor( + (sequence_length - 25) / 13) + self.classifier = nn.Sequential( + nn.Dropout(0.5), + nn.Linear(self._n_channels * 640, 925), + nn.ReLU(inplace=True), + nn.Linear(925, n_genomic_features), + nn.Sigmoid()) + + def forward(self, x): + """Forward propagation of a batch. + """ + out = self.nnet(x) + reshape_out = out.transpose(0, 1).transpose(0, 2) + out, _ = self.bdlstm(reshape_out) + out = out.transpose(0, 1) + reshape_out = out.contiguous().view( + out.size(0), 640 * self._n_channels) + predict = self.classifier(reshape_out) + return predict + +""" +def criterion(): + return nn.BCELoss() + +def get_optimizer(lr): + return (torch.optim.RMSprop, {"lr": lr}) +""" \ No newline at end of file diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py new file mode 100644 index 00000000..08276aa9 --- /dev/null +++ b/wilds/datasets/encodetfbs_dataset.py @@ -0,0 +1,147 @@ +import os +import torch +import pandas as pd +import numpy as np +from wilds.datasets.wilds_dataset import WILDSDataset +from wilds.common.grouper import CombinatorialGrouper +from wilds.common.metrics.eval_metric import Accuracy +from wilds.common.eval import standard_group_eval + +import IPython + +class EncodeTFBSDataset(WILDSDataset): + """ + EncodeTFBS dataset + Website: https://www.synapse.org/#!Synapse:syn6131484 + """ + + def __init__(self, root_dir, download, split_scheme): + self._dataset_name = 'encodeTFBS' + self._download_url = 'https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/' + self._data_dir = self.initialize_data_dir(root_dir, download) + self._y_size = 1 + self._n_classes = 2 + + self._tr_chrs = ['chr2', 'chr9', 'chr11'] + self._te_chrs = ['chr1', 'chr8', 'chr21'] + self._transcription_factor = 'MAX' + self._train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] + self._val_celltype = ['A549'] + self._test_celltype = ['GM12878'] + self._all_celltypes = self._train_celltypes + self._val_celltype + self._test_celltype + + self._metadata_fields = ['chr', 'celltype', 'y'] + self._metadata_map = {} + self._metadata_map['chr'] = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] + self._metadata_map['celltype'] = self._all_celltypes + + # Load sequence and DNase features + sequence_filename = os.path.join(self._data_dir, 'sequence.npz') + seq_arr = np.load(sequence_filename) + self._seq_bp = {} + for chrom in seq_arr: + self._seq_bp[chrom] = seq_arr[chrom] + + self._dnase_allcelltypes = {} + for ct in self._all_celltypes: + dnase_filename = os.path.join(self._data_dir, '{}_dnase.npz'.format(ct)) + dnase_npz_file = np.load(dnase_filename) + self._dnase_allcelltypes[ct] = {} + for chrom in seq_bp: + self._dnase_allcelltypes[ct][chrom] = dnase_npz_file[chrom] + + # Read in metadata dataframe from training+validation data + train_chr = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.train.labels.tsv.gz'.format(self._transcription_factor)), sep='\t') + val_chr = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.val.labels.tsv.gz'.format(self._transcription_factor)), sep='\t') + training_df = train_chr[np.isin(train_chr['chr'], self._tr_chrs)] + val_df = val_chr[np.isin(val_chr['chr'], self._te_chrs)] + all_df = pd.concat([training_df, val_df]) + + # Filter by start/stop coordinate if needed + filter_msk = all_df['start'] >= 0 + filter_msk = all_df['start']%1000 == 0 + all_df = all_df[filter_msk] + + pd_list = [] + for ct in self._train_celltypes: + tc_chr = all_df[['chr', 'start', 'stop', ct]] + tc_chr.columns = ['chr', 'start', 'stop', 'y'] + tc_chr['celltype'] = ct + pd_list.append(tc_chr) + metadata_df = pd.concat(pd_list) + + # Get the y values, and remove ambiguous labels by default. + y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values + non_ambig_mask = (y_array != -1) + metadata_df['y'] = y_array + self._metadata_df = metadata_df[non_ambig_mask] + self._y_array = torch.LongTensor(y_array[non_ambig_mask]) + + chr_ints = self._metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['chr'])] )).values + celltype_ints = self._metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['celltype'])] )).values + self._metadata_array = torch.stack( + (torch.LongTensor(chr_ints), + torch.LongTensor(celltype_ints), + self._y_array), + dim=1) + + # Get the splits + # TODO Extract splits as encoded in split_scheme. Hardcoded here for now. + self._split_scheme = split_scheme + self._split_dict = { + 'train': 0, + 'val-id': 1, + 'test': 2, + 'val-ood': 3 + } + self._split_names = { + 'train': 'Train', + 'val-id': 'Validation (ID)', + 'test': 'Test', + 'val-ood': 'Validation (OOD)', + } + train_chr_mask = np.isin(self._metadata_df['chr'], self._tr_chrs) + val_chr_mask = np.isin(self._metadata_df['chr'], self._te_chrs) + train_celltype_mask = np.isin(self._metadata_df['celltype'], self._train_celltypes) + val_celltype_mask = np.isin(self._metadata_df['celltype'], self._val_celltype) + test_celltype_mask = np.isin(self._metadata_df['celltype'], self._test_celltype) + + split_array = -1*np.ones(self._metadata_df.shape[0]).astype(int) + split_array[np.logical_and(train_chr_mask, train_celltype_mask)] = self._split_dict['train'] + split_array[np.logical_and(val_chr_mask, test_celltype_mask)] = self._split_dict['test'] + # Validate using test chr, either using a designated validation cell line ('val-ood') or a training cell line ('val-id') + split_array[np.logical_and(val_chr_mask, val_celltype_mask)] = self._split_dict['val-ood'] + split_array[np.logical_and(val_chr_mask, train_celltype_mask)] = self._split_dict['val-id'] + if self._split_scheme=='standard': + self._metadata_df['split'] = split_array + self._split_array = split_array + else: + raise ValueError(f'Split scheme {self._split_scheme} not recognized') + self._eval_grouper = CombinatorialGrouper( + dataset=self, + groupby_fields=['celltype']) + self._metric = Auprc() + + super().__init__(root_dir, download, split_scheme) + + def get_input(self, idx): + """ + Returns x for a given idx. + Computes this from: + (1) sequence features in self._seq_bp + (2) DNase features in self._dnase_allcelltypes + (3) Metadata for the index (location along the genome with 1kb window width) + """ + this_metadata = self._metadata_df.iloc[idx, :] + flank_size = 500 + interval_start = this_metadata['start'] - flank_size + interval_end = this_metadata['stop'] + flank_size + dnase_this = _dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end] + seq_this = _seq_bp[this_metadata['chr']][interval_start:interval_end] + return np.column_stack([seq_this, dnase_this]) + + def eval(self, y_pred, y_true, metadata): + return standard_group_eval( + self._metric, + self._eval_grouper, + y_pred, y_true, metadata) From 4975b8e00b42f488daf0fca17d69dc38632b8f94 Mon Sep 17 00:00:00 2001 From: aikanor Date: Tue, 2 Feb 2021 10:56:31 -0800 Subject: [PATCH 002/244] model/dataset fetching in nb 1/ --- .../sandbox_data-checkpoint.ipynb | 952 ++++++++++++++++++ .../encode-tfbs/sandbox_data.ipynb | 952 ++++++++++++++++++ sandbox_model.ipynb | 876 ++++++++++++++++ 3 files changed, 2780 insertions(+) create mode 100644 dataset_preprocessing/encode-tfbs/.ipynb_checkpoints/sandbox_data-checkpoint.ipynb create mode 100644 dataset_preprocessing/encode-tfbs/sandbox_data.ipynb create mode 100644 sandbox_model.ipynb diff --git a/dataset_preprocessing/encode-tfbs/.ipynb_checkpoints/sandbox_data-checkpoint.ipynb b/dataset_preprocessing/encode-tfbs/.ipynb_checkpoints/sandbox_data-checkpoint.ipynb new file mode 100644 index 00000000..b2e74829 --- /dev/null +++ b/dataset_preprocessing/encode-tfbs/.ipynb_checkpoints/sandbox_data-checkpoint.ipynb @@ -0,0 +1,952 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Initialize dataset object" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "57.5289368629\n", + "65.2459537983\n" + ] + } + ], + "source": [ + "import numpy as np, pandas as pd, os, time\n", + "import torch, torchvision\n", + "\n", + "data_dir = '/oak/stanford/groups/akundaje/abalsubr/DREAM/wilds/codalab_archive/'\n", + "tf = 'MAX'\n", + "itime = time.time()\n", + "train_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.train.labels.tsv.gz'.format(tf)), sep='\\t')\n", + "print(time.time() - itime)\n", + "val_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.val.labels.tsv.gz'.format(tf)), sep='\\t')\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", + "val_celltype = ['A549']\n", + "test_celltype = ['GM12878']\n", + "all_celltypes = train_celltypes + val_celltype + test_celltype\n", + "\n", + "metadata_map = {}\n", + "metadata_map['chr'] = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']\n", + "metadata_map['celltype'] = all_celltypes\n", + "\n", + "_split_dict = {\n", + " 'train': 0,\n", + " 'val-id': 1,\n", + " 'test': 2,\n", + " 'val-ood': 3\n", + "}\n", + "_split_names = {\n", + " 'train': 'Train',\n", + " 'val-id': 'Validation (ID)',\n", + " 'test': 'Test',\n", + " 'val-ood': 'Validation (OOD)',\n", + "}\n", + "_split_scheme = 'standard'" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.0467748641968\n", + "('chr1', 4.52302885055542)\n", + "('chr2', 8.645489931106567)\n", + "('chr3', 11.959153890609741)\n", + "('chr4', 15.15813684463501)\n", + "('chr5', 18.22238802909851)\n", + "('chr6', 21.19420099258423)\n", + "('chr7', 23.940655946731567)\n", + "('chr8', 26.415233850479126)\n", + "('chr9', 28.833614826202393)\n", + "('chr10', 31.08920383453369)\n", + "('chr11', 33.37020301818848)\n", + "('chr12', 35.98973989486694)\n", + "('chr13', 37.88540601730347)\n", + "('chr14', 39.68082284927368)\n", + "('chr15', 41.242313861846924)\n", + "('chr16', 42.74874496459961)\n", + "('chr17', 44.12280797958374)\n", + "('chr18', 45.46893382072449)\n", + "('chr19', 46.50577902793884)\n", + "('chr20', 47.59563183784485)\n", + "('chr21', 48.31779384613037)\n", + "('chr22', 49.17265295982361)\n", + "('chrX', 51.75806999206543)\n", + "('H1-hESC', 25.880441904067993)\n", + "('HCT116', 50.130937814712524)\n", + "('HeLa-S3', 75.29559993743896)\n", + "('HepG2', 102.25979495048523)\n", + "('K562', 128.43050694465637)\n", + "('A549', 154.80679488182068)\n", + "('GM12878', 182.0279529094696)\n" + ] + } + ], + "source": [ + "itime = time.time()\n", + "sequence_filename = os.path.join(data_dir, 'sequence.npz')\n", + "seq_arr = np.load(sequence_filename)\n", + "print(time.time() - itime)\n", + "\n", + "itime = time.time()\n", + "_seq_bp = {}\n", + "for chrom in seq_arr:\n", + " _seq_bp[chrom] = seq_arr[chrom]\n", + " print(chrom, time.time() - itime)\n", + "\n", + "itime = time.time()\n", + "_dnase_allcelltypes = {}\n", + "for ct in all_celltypes:\n", + " dnase_filename = os.path.join(data_dir, '{}_dnase.npz'.format(ct))\n", + " dnase_npz_file = np.load(dnase_filename)\n", + " _dnase_allcelltypes[ct] = {}\n", + " for chrom in _seq_bp:\n", + " _dnase_allcelltypes[ct][chrom] = dnase_npz_file[chrom]\n", + " print(ct, time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'all_df' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# len(_dnase_allcelltypes)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mall_df\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mNameError\u001b[0m: name 'all_df' is not defined" + ] + } + ], + "source": [ + "# len(_dnase_allcelltypes)\n", + "all_df" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "ename": "AttributeError", + "evalue": "'module' object has no attribute 'isin'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mtr_chrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr2'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr9'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr11'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mte_chrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr1'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr8'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr21'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mtraining_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtr_chrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mval_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mval_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mte_chrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mall_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtraining_df\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_df\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mAttributeError\u001b[0m: 'module' object has no attribute 'isin'" + ] + } + ], + "source": [ + "tr_chrs = ['chr2', 'chr9', 'chr11']\n", + "te_chrs = ['chr1', 'chr8', 'chr21']\n", + "training_df = train_chr[np.isin(train_chr['chr'], tr_chrs)]\n", + "val_df = val_chr[np.isin(val_chr['chr'], te_chrs)]\n", + "all_df = pd.concat([training_df, val_df])\n", + "\n", + "#filter_msk = all_df['start'] >= 0\n", + "filter_msk = all_df['start']%1000 == 0\n", + "all_df = all_df[filter_msk]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "itime = time.time()\n", + "pd_list = []\n", + "for ct in all_celltypes:\n", + " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", + " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", + " tc_chr['celltype'] = ct\n", + " pd_list.append(tc_chr)\n", + "metadata_df = pd.concat(pd_list)\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "itime = time.time()\n", + "y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", + "non_ambig_mask = (y_array != -1)\n", + "metadata_df['y'] = y_array\n", + "_metadata_df = metadata_df[non_ambig_mask]\n", + "_y_array = torch.LongTensor(y_array[non_ambig_mask])\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "itime = time.time()\n", + "chr_ints = _metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['chr'])] )).values\n", + "celltype_ints = _metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['celltype'])] )).values\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_chr_mask = np.isin(_metadata_df['chr'], tr_chrs)\n", + "val_chr_mask = np.isin(_metadata_df['chr'], te_chrs)\n", + "train_celltype_mask = np.isin(_metadata_df['celltype'], train_celltypes)\n", + "val_celltype_mask = np.isin(_metadata_df['celltype'], val_celltype)\n", + "test_celltype_mask = np.isin(_metadata_df['celltype'], test_celltype)\n", + "\n", + "split_array = -1*np.ones(_metadata_df.shape[0]).astype(int)\n", + "split_array[np.logical_and(train_chr_mask, train_celltype_mask)] = _split_dict['train']\n", + "split_array[np.logical_and(val_chr_mask, test_celltype_mask)] = _split_dict['test']\n", + "split_array[np.logical_and(val_chr_mask, val_celltype_mask)] = _split_dict['val-ood']\n", + "split_array[np.logical_and(val_chr_mask, train_celltype_mask)] = _split_dict['val-id']\n", + "_metadata_df['split'] = split_array\n", + "_split_array = split_array" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "ename": "ImportError", + "evalue": "No module named data", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdataset_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mImportError\u001b[0m: No module named data" + ] + } + ], + "source": [ + "from torch.utils.data import DataLoader\n", + "from data import dataset_attributes" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "from PIL import Image\n", + "import argparse\n", + "class ParseKwargs(argparse.Action):\n", + " def __call__(self, parser, namespace, values, option_string=None):\n", + " setattr(namespace, self.dest, dict())\n", + " for value in values:\n", + " key, value_str = value.split('=')\n", + " if value_str.replace('-','').isnumeric():\n", + " processed_val = int(value_str)\n", + " elif value_str.replace('-','').replace('.','').isnumeric():\n", + " processed_val = float(value_str)\n", + " elif value_str in ['True', 'true']:\n", + " processed_val = True\n", + " elif value_str in ['False', 'false']:\n", + " processed_val = False\n", + " else:\n", + " processed_val = value_str\n", + " getattr(namespace, self.dest)[key] = processed_val" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'algorithm_constructors' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0;31m# Algorithm and objective\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 33\u001b[0;31m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'--algorithm'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrequired\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mchoices\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0malgorithm_constructors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 34\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'--algorithm_kwargs'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'*'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maction\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mParseKwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'--groupby_fields'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'+'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name 'algorithm_constructors' is not defined" + ] + } + ], + "source": [ + "ROOTDIR = '/oak/stanford/groups/akundaje/abalsubr/wilds_other'\n", + "args_kw = \"-d camelyon17 --algorithm ERM --model densenet121 --split_scheme standard --groupby_fields hospital --loss_function cross_entropy --optimizer SGD --lr 0.0001 --batch_size 32 --weight_decay 0 --n_epochs 10 --scheduler ReduceLROnPlateau --scheduler_metric_split val --scheduler_metric_name acc_avg --log_dir log --log_every 50 --save_step 1000 --save_best --save_last --seed 0 --evaluate_all_splits --root_dir {}\".format(\n", + " ROOTDIR).split()\n", + "\n", + "parser = argparse.ArgumentParser()\n", + "\n", + "# Dataset\n", + "parser.add_argument('-d', '--dataset', choices=['encodeTFBS', 'amazon', 'camelyon17', 'celebA', 'civilcomments', 'iwildcam', 'waterbirds', 'yelp', 'poverty', 'fmow', 'ogbg-molpcba'], required=True)\n", + "parser.add_argument('--split_scheme', default='standard',\n", + " help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')\n", + "parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--root_dir', default=None, required=True,\n", + " help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')\n", + "parser.add_argument('--download', default=False, action='store_true',\n", + " help='If true, tries to downloads the dataset if it does not exist in root_dir.')\n", + "parser.add_argument('--frac', type=float, default=1.0,\n", + " help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.')\n", + "\n", + "# Loaders\n", + "parser.add_argument('--train_loader', choices=['standard', 'group'], default='standard')\n", + "parser.add_argument('--train_loader_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--eval_loader', choices=['standard', 'group'], default='standard')\n", + "parser.add_argument('--batch_size', type=int, default=32)\n", + "parser.add_argument('--no_pin_memory', action='store_true') # TODO: put as loader_kwargs\n", + "parser.add_argument('--num_workers', type=int, default=4) # TODO: put as loader kwargs\n", + "\n", + "# Model\n", + "parser.add_argument(\n", + " '--model',\n", + " choices=['bert-base-uncased', 'inception_v3', 'densenet121', 'wideresnet50', 'resnet50', 'gin-virtual', 'resnet18_ms'],\n", + " default='resnet50')\n", + "parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},\n", + " help='keyword arguments for model initialization passed as key1=value1 key2=value2')\n", + "parser.add_argument('--train_from_scratch', action='store_true', default=False)\n", + "\n", + "# Algorithm and objective\n", + "parser.add_argument('--algorithm', required=True, choices=algorithm_constructors.keys())\n", + "parser.add_argument('--algorithm_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--groupby_fields', nargs='+', default=None)\n", + "parser.add_argument('--loss_function', required=True, choices = losses.keys()) #TODO: make default\n", + "parser.add_argument('--val_metric', default=None)\n", + "\n", + "# Optimization\n", + "parser.add_argument('--n_epochs', type=int, default=4)\n", + "parser.add_argument('--optimizer', default=None, choices=optimizer_attributes.keys())\n", + "parser.add_argument('--lr', type=float, required=True)\n", + "parser.add_argument('--weight_decay', type=float, required=True)\n", + "parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--scheduler', default=None, choices=scheduler_attributes.keys())\n", + "parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')\n", + "parser.add_argument('--scheduler_metric_name')\n", + "\n", + "# Evaluation\n", + "parser.add_argument('--evaluate_all_splits', action='store_true', default=False)\n", + "parser.add_argument('--additional_eval_splits', nargs='+', default=[])\n", + "\n", + "# Misc\n", + "parser.add_argument('--device', type=int, default=0)\n", + "parser.add_argument('--seed', type=int, default=0)\n", + "parser.add_argument('--log_dir', default='./logs')\n", + "parser.add_argument('--log_every', default=50, type=int)\n", + "parser.add_argument('--save_step', type=int, default=None)\n", + "parser.add_argument('--save_best', action='store_true', default=False)\n", + "parser.add_argument('--save_last', action='store_true', default=False)\n", + "parser.add_argument('--save_outputs', action='store_true', default=False)\n", + "parser.add_argument('--no_group_logging', action='store_true', default=False)\n", + "parser.add_argument('--val_metric_decreasing', action='store_true', default=False)\n", + "parser.add_argument('--use_wandb', action='store_true', default=False)\n", + "parser.add_argument('--progress_bar', action='store_true', default=False)\n", + "parser.add_argument('--resume', default=False, action='store_true')\n", + "parser.add_argument('--eval_only', default=False, action='store_true')\n", + "\n", + "args = parser.parse_args(args_kw)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# get_input (idx)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name '_metadata_df' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0midx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mthis_metadata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_metadata_df\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miloc\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mitime\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mflank_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m400\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name '_metadata_df' is not defined" + ] + } + ], + "source": [ + "idx = 3\n", + "this_metadata = _metadata_df.iloc[idx, :]\n", + "\n", + "itime = time.time()\n", + "flank_size = 400\n", + "interval_start = this_metadata['start'] - flank_size\n", + "interval_end = this_metadata['stop'] + flank_size\n", + "dnase_this = _dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end]\n", + "seq_this = _seq_bp[this_metadata['chr']][interval_start:interval_end]\n", + "data = np.column_stack([seq_this, dnase_this])\n", + "# print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.028102874755859375\n" + ] + } + ], + "source": [ + "itime = time.time()\n", + "metadata_array = torch.stack(\n", + " (torch.LongTensor(chr_ints), \n", + " torch.LongTensor(celltype_ints), \n", + " _y_array),\n", + " dim=1)\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'torch_scatter'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m#data.shape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata_loaders\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_train_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mget_eval_loader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/dr_benchmark/wilds/common/data_loaders.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msampler\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mWeightedRandomSampler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSubsetRandomSampler\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_counts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msplit_into_groups\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/dr_benchmark/wilds/common/utils.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch_scatter\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mSubset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mpandas\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtypes\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCategoricalDtype\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torch_scatter'" + ] + } + ], + "source": [ + "#data.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 157, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "4600" + ] + }, + "execution_count": 157, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data.shape\n", + "interval_end\n", + "# itime = time.time()\n", + "# np.save(os.path.join(data_dir, 'stmp.npy'), sa)\n", + "# print(time.time() - itime)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Run training experiment" + ] + }, + { + "cell_type": "code", + "execution_count": 167, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'python3 examples/run_expt.py -d encodeTFBS --algorithm ERM --model densenet121 --split_scheme standard --groupby_fields hospital --loss_function cross_entropy --optimizer SGD --lr 0.0001 --batch_size 32 --weight_decay 0 --n_epochs 10 --scheduler ReduceLROnPlateau --scheduler_metric_split val --scheduler_metric_name acc_avg --log_dir log --log_every 50 --save_step 1000 --save_best --save_last --seed 0 --evaluate_all_splits --root_dir ROOTDIR'" + ] + }, + "execution_count": 167, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cmdstr = \"python3 examples/run_expt.py -d encodeTFBS --algorithm ERM --model densenet121 --split_scheme standard --groupby_fields hospital --loss_function cross_entropy\"\n", + "cmdstr += \" \"\n", + "cmdstr += \"--optimizer SGD --lr 0.0001 --batch_size 32 --weight_decay 0 --n_epochs 10 --scheduler ReduceLROnPlateau --scheduler_metric_split val --scheduler_metric_name acc_avg\"\n", + "cmdstr += \" \"\n", + "cmdstr += \"--log_dir log --log_every 50 --save_step 1000 --save_best --save_last --seed 0 --evaluate_all_splits --root_dir ROOTDIR\"\n", + "cmdstr" + ] + }, + { + "cell_type": "code", + "execution_count": 164, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name '_metadata_array' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0m_metadata_array\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mNameError\u001b[0m: name '_metadata_array' is not defined" + ] + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 165, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'torch_scatter'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minsert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'..'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata_loaders\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_train_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mget_eval_loader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 17\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrouper\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCombinatorialGrouper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_counts\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/dr_benchmark/wilds/common/data_loaders.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msampler\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mWeightedRandomSampler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSubsetRandomSampler\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_counts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msplit_into_groups\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/dr_benchmark/wilds/common/utils.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch_scatter\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mSubset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mpandas\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtypes\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCategoricalDtype\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torch_scatter'" + ] + } + ], + "source": [ + "import os, csv\n", + "import time\n", + "import argparse\n", + "import IPython\n", + "import pandas as pd\n", + "import torch\n", + "import torch.nn as nn\n", + "import torchvision\n", + "import sys\n", + "from collections import defaultdict\n", + "# torch.multiprocessing.set_sharing_strategy('file_system')\n", + "\n", + "# TODO: Replace this once we make wilds into an installed package\n", + "sys.path.insert(1, os.path.join(sys.path[0], '..'))\n", + "\n", + "from wilds.common.data_loaders import get_train_loader, get_eval_loader\n", + "from wilds.common.grouper import CombinatorialGrouper\n", + "from wilds.common.utils import get_counts\n", + "\n", + "from models.model_attributes import model_attributes\n", + "from utils import set_seed, Logger, BatchLogger, log_args, ParseKwargs, load\n", + "from train import train, evaluate\n", + "from data import dataset_attributes\n", + "from optimizer import optimizer_attributes\n", + "from scheduler import scheduler_attributes\n", + "from loss import losses\n", + "from utils import log_group_data\n", + "from algorithms.constructors import algorithm_constructors" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from examples.models.model_attributes import model_attributes" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'utils'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_attributes\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmodel_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mset_seed\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mLogger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mCSVBatchLogger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mParseKwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mload\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 22\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdataset_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0moptimizer_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/dr_benchmark/examples/train.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0msave\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'utils'" + ] + } + ], + "source": [ + "def initialize_algorithm(args, datasets, train_grouper):\n", + " train_dataset = datasets['train']['dataset']\n", + " train_loader = datasets['train']['loader']\n", + "\n", + " # Configure the final layer of the networks used\n", + " # The code below are defaults. Edit this if you need special config for your model.\n", + " if (train_dataset.is_classification) and (train_dataset.y_size == 1):\n", + " # For single-task classification, we have one output per class\n", + " d_out = train_dataset.n_classes\n", + " elif (train_dataset.is_classification) and (train_dataset.y_size > 1) and (train_dataset.n_classes == 2):\n", + " # For multi-task binary classification (each output is the logit for each binary class)\n", + " d_out = train_dataset.y_size\n", + " elif (not train_dataset.is_classification):\n", + " # For regression, we have one output per target dimension\n", + " d_out = train_dataset.y_size\n", + " else:\n", + " raise RuntimeError('d_out not defined.')\n", + " \n", + "\n", + " # Sanity checking input args\n", + " if args.algorithm == 'groupDRO':\n", + " assert args.train_loader_kwargs['uniform_over_groups']\n", + " elif args.algorithm in ['deepCORAL', 'IRM']:\n", + " assert args.train_loader == 'group'\n", + " assert args.train_loader_kwargs['uniform_over_groups']\n", + " assert args.train_loader_kwargs['distinct_groups']\n", + "\n", + " # Other config\n", + " n_train_steps = len(train_loader) * args.n_epochs\n", + "# prediction_fn = dataset_attributes[args.dataset]['prediction_fn']\n", + " loss = losses[args.loss_function]\n", + " metric = dataset_attributes[args.dataset]['metric']\n", + " train_g = train_grouper.metadata_to_group(train_dataset.metadata_array)\n", + " is_group_in_train = get_counts(train_g, train_grouper.n_groups) > 0\n", + " algorithm_constructor = algorithm_constructors[args.algorithm]\n", + " algorithm = algorithm_constructor(\n", + " args=args,\n", + " d_out=d_out,\n", + " grouper=train_grouper,\n", + " loss=loss,\n", + " metric=metric,\n", + " n_train_steps=n_train_steps,\n", + " is_group_in_train=is_group_in_train)\n", + " return algorithm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def main():\n", + " parser = argparse.ArgumentParser()\n", + "\n", + " # Dataset\n", + " parser.add_argument('-d', '--dataset', choices=dataset_attributes.keys(), required=True)\n", + " parser.add_argument('--split_scheme', default='standard',\n", + " help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')\n", + " parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={})\n", + " parser.add_argument('--root_dir', default=None, required=True,\n", + " help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')\n", + " parser.add_argument('--download', default=False, action='store_true',\n", + " help='If true, tries to downloads the dataset if it does not exist in root_dir.')\n", + " parser.add_argument('--frac', type=float, default=1.0,\n", + " help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.')\n", + "\n", + " # Loaders\n", + " parser.add_argument('--train_loader', choices=['standard', 'group'], default='standard')\n", + " parser.add_argument('--train_loader_kwargs', nargs='*', action=ParseKwargs, default={})\n", + " parser.add_argument('--eval_loader', choices=['standard', 'group'], default='standard')\n", + " parser.add_argument('--batch_size', type=int, default=32)\n", + " parser.add_argument('--no_pin_memory', action='store_true') # TODO: put as loader_kwargs\n", + " parser.add_argument('--num_workers', type=int, default=4) # TODO: put as loader kwargs\n", + "\n", + " # Model\n", + " parser.add_argument(\n", + " '--model',\n", + " choices=model_attributes.keys(),\n", + " default='resnet50')\n", + " parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},\n", + " help='keyword arguments for model initialization passed as key1=value1 key2=value2')\n", + " parser.add_argument('--train_from_scratch', action='store_true', default=False)\n", + "\n", + " # Algorithm and objective\n", + " parser.add_argument('--algorithm', required=True, choices=algorithm_constructors.keys())\n", + " parser.add_argument('--algorithm_kwargs', nargs='*', action=ParseKwargs, default={})\n", + " parser.add_argument('--groupby_fields', nargs='+', default=None)\n", + " parser.add_argument('--loss_function', required=True, choices = losses.keys()) #TODO: make default\n", + " parser.add_argument('--val_metric', default=None)\n", + "\n", + " # Optimization\n", + " parser.add_argument('--n_epochs', type=int, default=4)\n", + " parser.add_argument('--optimizer', default=None, choices=optimizer_attributes.keys())\n", + " parser.add_argument('--lr', type=float, required=True)\n", + " parser.add_argument('--weight_decay', type=float, required=True)\n", + " parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})\n", + " parser.add_argument('--scheduler', default=None, choices=scheduler_attributes.keys())\n", + " parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})\n", + " parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')\n", + " parser.add_argument('--scheduler_metric_name')\n", + "\n", + " # Evaluation\n", + " parser.add_argument('--evaluate_all_splits', action='store_true', default=False)\n", + " parser.add_argument('--additional_eval_splits', nargs='+', default=[])\n", + "\n", + " # Misc\n", + " parser.add_argument('--device', type=int, default=0)\n", + " parser.add_argument('--seed', type=int, default=0)\n", + " parser.add_argument('--log_dir', default='./logs')\n", + " parser.add_argument('--log_every', default=50, type=int)\n", + " parser.add_argument('--save_step', type=int, default=None)\n", + " parser.add_argument('--save_best', action='store_true', default=False)\n", + " parser.add_argument('--save_last', action='store_true', default=False)\n", + " parser.add_argument('--save_outputs', action='store_true', default=False)\n", + " parser.add_argument('--no_group_logging', action='store_true', default=False)\n", + " parser.add_argument('--val_metric_decreasing', action='store_true', default=False)\n", + " parser.add_argument('--use_wandb', action='store_true', default=False)\n", + " parser.add_argument('--progress_bar', action='store_true', default=False)\n", + " parser.add_argument('--resume', default=False, action='store_true')\n", + " parser.add_argument('--eval_only', default=False, action='store_true')\n", + "\n", + " args = parser.parse_args()\n", + "\n", + " # set device\n", + " args.device = torch.device(\"cuda:\" + str(args.device)) if torch.cuda.is_available() else torch.device(\"cpu\")\n", + "\n", + " # Set defaults\n", + " if args.groupby_fields is None:\n", + " args.no_group_logging = True\n", + " if args.val_metric is None:\n", + " args.val_metric = dataset_attributes[args.dataset]['val_metric']\n", + "\n", + " ## Initialize logs\n", + " if os.path.exists(args.log_dir) and args.resume:\n", + " resume=True\n", + " mode='a'\n", + " else:\n", + " resume=False\n", + " mode='w'\n", + " if not os.path.exists(args.log_dir):\n", + " os.makedirs(args.log_dir)\n", + " logger = Logger(os.path.join(args.log_dir, 'log.txt'), mode)\n", + "\n", + " # Record args\n", + " log_args(args, logger)\n", + "\n", + " # Set random seed\n", + " set_seed(args.seed)\n", + "\n", + " # Data\n", + " full_dataset = dataset_attributes[args.dataset]['constructor'](\n", + " root_dir=args.root_dir,\n", + " download=args.download,\n", + " split_scheme=args.split_scheme,\n", + " **args.dataset_kwargs)\n", + "\n", + " # To implement data augmentation (i.e., have different transforms\n", + " # at training time vs. test time), modify these two lines:\n", + " train_transform = dataset_attributes[args.dataset]['transform'](args.model)\n", + " if dataset_attributes[args.dataset].get('eval_transform') is None:\n", + " eval_transform = dataset_attributes[args.dataset]['transform'](args.model)\n", + " else:\n", + " eval_transform = dataset_attributes[args.dataset]['eval_transform'](args.model)\n", + "\n", + " train_grouper = CombinatorialGrouper(\n", + " dataset=full_dataset,\n", + " groupby_fields=args.groupby_fields)\n", + "\n", + " datasets = defaultdict(dict)\n", + " for split in full_dataset.split_dict.keys():\n", + " if split=='train':\n", + " transform = train_transform\n", + " verbose = True\n", + " elif split == 'val':\n", + " transform = eval_transform\n", + " verbose = True\n", + " else:\n", + " transform = eval_transform\n", + " verbose = False\n", + " # Get subset\n", + " datasets[split]['dataset'] = full_dataset.get_subset(\n", + " split,\n", + " frac=args.frac,\n", + " transform=transform)\n", + "\n", + " # Get loader\n", + " shared_loader_kwargs = {\n", + " 'num_workers': args.num_workers,\n", + " 'pin_memory': not args.no_pin_memory,\n", + " 'batch_size': args.batch_size,\n", + " 'collate_fn': dataset_attributes[args.dataset]['collate']\n", + " }\n", + "\n", + " if split == 'train':\n", + " datasets[split]['loader'] = get_train_loader(\n", + " loader=args.train_loader,\n", + " dataset=datasets[split]['dataset'],\n", + " grouper=train_grouper,\n", + " train_loader_kwargs=args.train_loader_kwargs,\n", + " **shared_loader_kwargs)\n", + " else:\n", + " datasets[split]['loader'] = get_eval_loader(\n", + " loader=args.eval_loader,\n", + " dataset=datasets[split]['dataset'],\n", + " grouper=train_grouper,\n", + " **shared_loader_kwargs)\n", + "\n", + " # Set fields\n", + " datasets[split]['split'] = split\n", + " datasets[split]['name'] = full_dataset.split_names[split]\n", + " datasets[split]['verbose'] = verbose\n", + " # Loggers\n", + " # Loggers\n", + " datasets[split]['eval_logger'] = BatchLogger(\n", + " os.path.join(args.log_dir, f'{split}_eval.csv'), mode=mode, use_wandb=args.use_wandb)\n", + " datasets[split]['algo_logger'] = BatchLogger(\n", + " os.path.join(args.log_dir, f'{split}_algo.csv'), mode=mode, use_wandb=args.use_wandb)\n", + "\n", + " if args.use_wandb:\n", + " initialize_wandb(args)\n", + "\n", + " # Logging dataset info\n", + " if args.no_group_logging and full_dataset.is_classification and full_dataset.y_size==1:\n", + " log_grouper = CombinatorialGrouper(\n", + " dataset=full_dataset,\n", + " groupby_fields=['y'])\n", + " elif args.no_group_logging:\n", + " log_grouper = None\n", + " else:\n", + " log_grouper = train_grouper\n", + " log_group_data(args, datasets, log_grouper, logger)\n", + "\n", + " ## Initialize algorithm\n", + " algorithm = initialize_algorithm(args, datasets, train_grouper)\n", + "\n", + " if not args.eval_only:\n", + " ## Load saved results if resuming\n", + " resume_success = False\n", + " if resume:\n", + " save_path = os.path.join(args.log_dir, 'last_model.pth')\n", + " if not os.path.exists(save_path):\n", + " epochs = [\n", + " int(file.split('_')[0])\n", + " for file in os.listdir(args.log_dir) if file.endswith('.pth')]\n", + " if len(epochs) > 0:\n", + " latest_epoch = max(epochs)\n", + " save_path = os.path.join(args.log_dir, f'{latest_epoch}_model.pth')\n", + " try:\n", + " prev_epoch, best_val_metric = load(algorithm, save_path)\n", + " epoch_offset = prev_epoch + 1\n", + " logger.write(f'Resuming from epoch {epoch_offset} with best val metric {best_val_metric}')\n", + " resume_success = True\n", + " except FileNotFoundError:\n", + " pass\n", + "\n", + " if resume_success == False:\n", + " epoch_offset=0\n", + " best_val_metric=None\n", + "\n", + "\n", + " train(algorithm,\n", + " datasets,\n", + " logger,\n", + " args,\n", + " epoch_offset=epoch_offset,\n", + " best_val_metric=best_val_metric)\n", + " else:\n", + " best_model_path = os.path.join(args.log_dir, 'best_model.pth')\n", + " best_epoch, best_val_metric = load(algorithm, best_model_path)\n", + " evaluate(algorithm, datasets, best_epoch, logger)\n", + "\n", + " logger.close()\n", + " for split in datasets:\n", + " datasets[split]['eval_logger'].close()\n", + " datasets[split]['algo_logger'].close()\n", + "\n", + "if __name__=='__main__':\n", + " main()\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "language": "python", + "name": "python2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/dataset_preprocessing/encode-tfbs/sandbox_data.ipynb b/dataset_preprocessing/encode-tfbs/sandbox_data.ipynb new file mode 100644 index 00000000..b2e74829 --- /dev/null +++ b/dataset_preprocessing/encode-tfbs/sandbox_data.ipynb @@ -0,0 +1,952 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Initialize dataset object" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "57.5289368629\n", + "65.2459537983\n" + ] + } + ], + "source": [ + "import numpy as np, pandas as pd, os, time\n", + "import torch, torchvision\n", + "\n", + "data_dir = '/oak/stanford/groups/akundaje/abalsubr/DREAM/wilds/codalab_archive/'\n", + "tf = 'MAX'\n", + "itime = time.time()\n", + "train_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.train.labels.tsv.gz'.format(tf)), sep='\\t')\n", + "print(time.time() - itime)\n", + "val_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.val.labels.tsv.gz'.format(tf)), sep='\\t')\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", + "val_celltype = ['A549']\n", + "test_celltype = ['GM12878']\n", + "all_celltypes = train_celltypes + val_celltype + test_celltype\n", + "\n", + "metadata_map = {}\n", + "metadata_map['chr'] = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']\n", + "metadata_map['celltype'] = all_celltypes\n", + "\n", + "_split_dict = {\n", + " 'train': 0,\n", + " 'val-id': 1,\n", + " 'test': 2,\n", + " 'val-ood': 3\n", + "}\n", + "_split_names = {\n", + " 'train': 'Train',\n", + " 'val-id': 'Validation (ID)',\n", + " 'test': 'Test',\n", + " 'val-ood': 'Validation (OOD)',\n", + "}\n", + "_split_scheme = 'standard'" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.0467748641968\n", + "('chr1', 4.52302885055542)\n", + "('chr2', 8.645489931106567)\n", + "('chr3', 11.959153890609741)\n", + "('chr4', 15.15813684463501)\n", + "('chr5', 18.22238802909851)\n", + "('chr6', 21.19420099258423)\n", + "('chr7', 23.940655946731567)\n", + "('chr8', 26.415233850479126)\n", + "('chr9', 28.833614826202393)\n", + "('chr10', 31.08920383453369)\n", + "('chr11', 33.37020301818848)\n", + "('chr12', 35.98973989486694)\n", + "('chr13', 37.88540601730347)\n", + "('chr14', 39.68082284927368)\n", + "('chr15', 41.242313861846924)\n", + "('chr16', 42.74874496459961)\n", + "('chr17', 44.12280797958374)\n", + "('chr18', 45.46893382072449)\n", + "('chr19', 46.50577902793884)\n", + "('chr20', 47.59563183784485)\n", + "('chr21', 48.31779384613037)\n", + "('chr22', 49.17265295982361)\n", + "('chrX', 51.75806999206543)\n", + "('H1-hESC', 25.880441904067993)\n", + "('HCT116', 50.130937814712524)\n", + "('HeLa-S3', 75.29559993743896)\n", + "('HepG2', 102.25979495048523)\n", + "('K562', 128.43050694465637)\n", + "('A549', 154.80679488182068)\n", + "('GM12878', 182.0279529094696)\n" + ] + } + ], + "source": [ + "itime = time.time()\n", + "sequence_filename = os.path.join(data_dir, 'sequence.npz')\n", + "seq_arr = np.load(sequence_filename)\n", + "print(time.time() - itime)\n", + "\n", + "itime = time.time()\n", + "_seq_bp = {}\n", + "for chrom in seq_arr:\n", + " _seq_bp[chrom] = seq_arr[chrom]\n", + " print(chrom, time.time() - itime)\n", + "\n", + "itime = time.time()\n", + "_dnase_allcelltypes = {}\n", + "for ct in all_celltypes:\n", + " dnase_filename = os.path.join(data_dir, '{}_dnase.npz'.format(ct))\n", + " dnase_npz_file = np.load(dnase_filename)\n", + " _dnase_allcelltypes[ct] = {}\n", + " for chrom in _seq_bp:\n", + " _dnase_allcelltypes[ct][chrom] = dnase_npz_file[chrom]\n", + " print(ct, time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'all_df' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# len(_dnase_allcelltypes)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mall_df\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mNameError\u001b[0m: name 'all_df' is not defined" + ] + } + ], + "source": [ + "# len(_dnase_allcelltypes)\n", + "all_df" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "ename": "AttributeError", + "evalue": "'module' object has no attribute 'isin'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mtr_chrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr2'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr9'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr11'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mte_chrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr1'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr8'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr21'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mtraining_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtr_chrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mval_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mval_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mte_chrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mall_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtraining_df\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_df\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mAttributeError\u001b[0m: 'module' object has no attribute 'isin'" + ] + } + ], + "source": [ + "tr_chrs = ['chr2', 'chr9', 'chr11']\n", + "te_chrs = ['chr1', 'chr8', 'chr21']\n", + "training_df = train_chr[np.isin(train_chr['chr'], tr_chrs)]\n", + "val_df = val_chr[np.isin(val_chr['chr'], te_chrs)]\n", + "all_df = pd.concat([training_df, val_df])\n", + "\n", + "#filter_msk = all_df['start'] >= 0\n", + "filter_msk = all_df['start']%1000 == 0\n", + "all_df = all_df[filter_msk]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "itime = time.time()\n", + "pd_list = []\n", + "for ct in all_celltypes:\n", + " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", + " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", + " tc_chr['celltype'] = ct\n", + " pd_list.append(tc_chr)\n", + "metadata_df = pd.concat(pd_list)\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "itime = time.time()\n", + "y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", + "non_ambig_mask = (y_array != -1)\n", + "metadata_df['y'] = y_array\n", + "_metadata_df = metadata_df[non_ambig_mask]\n", + "_y_array = torch.LongTensor(y_array[non_ambig_mask])\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "itime = time.time()\n", + "chr_ints = _metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['chr'])] )).values\n", + "celltype_ints = _metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['celltype'])] )).values\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_chr_mask = np.isin(_metadata_df['chr'], tr_chrs)\n", + "val_chr_mask = np.isin(_metadata_df['chr'], te_chrs)\n", + "train_celltype_mask = np.isin(_metadata_df['celltype'], train_celltypes)\n", + "val_celltype_mask = np.isin(_metadata_df['celltype'], val_celltype)\n", + "test_celltype_mask = np.isin(_metadata_df['celltype'], test_celltype)\n", + "\n", + "split_array = -1*np.ones(_metadata_df.shape[0]).astype(int)\n", + "split_array[np.logical_and(train_chr_mask, train_celltype_mask)] = _split_dict['train']\n", + "split_array[np.logical_and(val_chr_mask, test_celltype_mask)] = _split_dict['test']\n", + "split_array[np.logical_and(val_chr_mask, val_celltype_mask)] = _split_dict['val-ood']\n", + "split_array[np.logical_and(val_chr_mask, train_celltype_mask)] = _split_dict['val-id']\n", + "_metadata_df['split'] = split_array\n", + "_split_array = split_array" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "ename": "ImportError", + "evalue": "No module named data", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdataset_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mImportError\u001b[0m: No module named data" + ] + } + ], + "source": [ + "from torch.utils.data import DataLoader\n", + "from data import dataset_attributes" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "from PIL import Image\n", + "import argparse\n", + "class ParseKwargs(argparse.Action):\n", + " def __call__(self, parser, namespace, values, option_string=None):\n", + " setattr(namespace, self.dest, dict())\n", + " for value in values:\n", + " key, value_str = value.split('=')\n", + " if value_str.replace('-','').isnumeric():\n", + " processed_val = int(value_str)\n", + " elif value_str.replace('-','').replace('.','').isnumeric():\n", + " processed_val = float(value_str)\n", + " elif value_str in ['True', 'true']:\n", + " processed_val = True\n", + " elif value_str in ['False', 'false']:\n", + " processed_val = False\n", + " else:\n", + " processed_val = value_str\n", + " getattr(namespace, self.dest)[key] = processed_val" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'algorithm_constructors' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0;31m# Algorithm and objective\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 33\u001b[0;31m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'--algorithm'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrequired\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mchoices\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0malgorithm_constructors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 34\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'--algorithm_kwargs'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'*'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maction\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mParseKwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'--groupby_fields'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'+'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name 'algorithm_constructors' is not defined" + ] + } + ], + "source": [ + "ROOTDIR = '/oak/stanford/groups/akundaje/abalsubr/wilds_other'\n", + "args_kw = \"-d camelyon17 --algorithm ERM --model densenet121 --split_scheme standard --groupby_fields hospital --loss_function cross_entropy --optimizer SGD --lr 0.0001 --batch_size 32 --weight_decay 0 --n_epochs 10 --scheduler ReduceLROnPlateau --scheduler_metric_split val --scheduler_metric_name acc_avg --log_dir log --log_every 50 --save_step 1000 --save_best --save_last --seed 0 --evaluate_all_splits --root_dir {}\".format(\n", + " ROOTDIR).split()\n", + "\n", + "parser = argparse.ArgumentParser()\n", + "\n", + "# Dataset\n", + "parser.add_argument('-d', '--dataset', choices=['encodeTFBS', 'amazon', 'camelyon17', 'celebA', 'civilcomments', 'iwildcam', 'waterbirds', 'yelp', 'poverty', 'fmow', 'ogbg-molpcba'], required=True)\n", + "parser.add_argument('--split_scheme', default='standard',\n", + " help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')\n", + "parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--root_dir', default=None, required=True,\n", + " help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')\n", + "parser.add_argument('--download', default=False, action='store_true',\n", + " help='If true, tries to downloads the dataset if it does not exist in root_dir.')\n", + "parser.add_argument('--frac', type=float, default=1.0,\n", + " help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.')\n", + "\n", + "# Loaders\n", + "parser.add_argument('--train_loader', choices=['standard', 'group'], default='standard')\n", + "parser.add_argument('--train_loader_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--eval_loader', choices=['standard', 'group'], default='standard')\n", + "parser.add_argument('--batch_size', type=int, default=32)\n", + "parser.add_argument('--no_pin_memory', action='store_true') # TODO: put as loader_kwargs\n", + "parser.add_argument('--num_workers', type=int, default=4) # TODO: put as loader kwargs\n", + "\n", + "# Model\n", + "parser.add_argument(\n", + " '--model',\n", + " choices=['bert-base-uncased', 'inception_v3', 'densenet121', 'wideresnet50', 'resnet50', 'gin-virtual', 'resnet18_ms'],\n", + " default='resnet50')\n", + "parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},\n", + " help='keyword arguments for model initialization passed as key1=value1 key2=value2')\n", + "parser.add_argument('--train_from_scratch', action='store_true', default=False)\n", + "\n", + "# Algorithm and objective\n", + "parser.add_argument('--algorithm', required=True, choices=algorithm_constructors.keys())\n", + "parser.add_argument('--algorithm_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--groupby_fields', nargs='+', default=None)\n", + "parser.add_argument('--loss_function', required=True, choices = losses.keys()) #TODO: make default\n", + "parser.add_argument('--val_metric', default=None)\n", + "\n", + "# Optimization\n", + "parser.add_argument('--n_epochs', type=int, default=4)\n", + "parser.add_argument('--optimizer', default=None, choices=optimizer_attributes.keys())\n", + "parser.add_argument('--lr', type=float, required=True)\n", + "parser.add_argument('--weight_decay', type=float, required=True)\n", + "parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--scheduler', default=None, choices=scheduler_attributes.keys())\n", + "parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')\n", + "parser.add_argument('--scheduler_metric_name')\n", + "\n", + "# Evaluation\n", + "parser.add_argument('--evaluate_all_splits', action='store_true', default=False)\n", + "parser.add_argument('--additional_eval_splits', nargs='+', default=[])\n", + "\n", + "# Misc\n", + "parser.add_argument('--device', type=int, default=0)\n", + "parser.add_argument('--seed', type=int, default=0)\n", + "parser.add_argument('--log_dir', default='./logs')\n", + "parser.add_argument('--log_every', default=50, type=int)\n", + "parser.add_argument('--save_step', type=int, default=None)\n", + "parser.add_argument('--save_best', action='store_true', default=False)\n", + "parser.add_argument('--save_last', action='store_true', default=False)\n", + "parser.add_argument('--save_outputs', action='store_true', default=False)\n", + "parser.add_argument('--no_group_logging', action='store_true', default=False)\n", + "parser.add_argument('--val_metric_decreasing', action='store_true', default=False)\n", + "parser.add_argument('--use_wandb', action='store_true', default=False)\n", + "parser.add_argument('--progress_bar', action='store_true', default=False)\n", + "parser.add_argument('--resume', default=False, action='store_true')\n", + "parser.add_argument('--eval_only', default=False, action='store_true')\n", + "\n", + "args = parser.parse_args(args_kw)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# get_input (idx)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name '_metadata_df' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0midx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mthis_metadata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_metadata_df\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miloc\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mitime\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mflank_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m400\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name '_metadata_df' is not defined" + ] + } + ], + "source": [ + "idx = 3\n", + "this_metadata = _metadata_df.iloc[idx, :]\n", + "\n", + "itime = time.time()\n", + "flank_size = 400\n", + "interval_start = this_metadata['start'] - flank_size\n", + "interval_end = this_metadata['stop'] + flank_size\n", + "dnase_this = _dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end]\n", + "seq_this = _seq_bp[this_metadata['chr']][interval_start:interval_end]\n", + "data = np.column_stack([seq_this, dnase_this])\n", + "# print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.028102874755859375\n" + ] + } + ], + "source": [ + "itime = time.time()\n", + "metadata_array = torch.stack(\n", + " (torch.LongTensor(chr_ints), \n", + " torch.LongTensor(celltype_ints), \n", + " _y_array),\n", + " dim=1)\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'torch_scatter'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m#data.shape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata_loaders\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_train_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mget_eval_loader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/dr_benchmark/wilds/common/data_loaders.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msampler\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mWeightedRandomSampler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSubsetRandomSampler\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_counts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msplit_into_groups\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/dr_benchmark/wilds/common/utils.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch_scatter\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mSubset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mpandas\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtypes\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCategoricalDtype\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torch_scatter'" + ] + } + ], + "source": [ + "#data.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 157, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "4600" + ] + }, + "execution_count": 157, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data.shape\n", + "interval_end\n", + "# itime = time.time()\n", + "# np.save(os.path.join(data_dir, 'stmp.npy'), sa)\n", + "# print(time.time() - itime)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Run training experiment" + ] + }, + { + "cell_type": "code", + "execution_count": 167, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'python3 examples/run_expt.py -d encodeTFBS --algorithm ERM --model densenet121 --split_scheme standard --groupby_fields hospital --loss_function cross_entropy --optimizer SGD --lr 0.0001 --batch_size 32 --weight_decay 0 --n_epochs 10 --scheduler ReduceLROnPlateau --scheduler_metric_split val --scheduler_metric_name acc_avg --log_dir log --log_every 50 --save_step 1000 --save_best --save_last --seed 0 --evaluate_all_splits --root_dir ROOTDIR'" + ] + }, + "execution_count": 167, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cmdstr = \"python3 examples/run_expt.py -d encodeTFBS --algorithm ERM --model densenet121 --split_scheme standard --groupby_fields hospital --loss_function cross_entropy\"\n", + "cmdstr += \" \"\n", + "cmdstr += \"--optimizer SGD --lr 0.0001 --batch_size 32 --weight_decay 0 --n_epochs 10 --scheduler ReduceLROnPlateau --scheduler_metric_split val --scheduler_metric_name acc_avg\"\n", + "cmdstr += \" \"\n", + "cmdstr += \"--log_dir log --log_every 50 --save_step 1000 --save_best --save_last --seed 0 --evaluate_all_splits --root_dir ROOTDIR\"\n", + "cmdstr" + ] + }, + { + "cell_type": "code", + "execution_count": 164, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name '_metadata_array' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0m_metadata_array\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mNameError\u001b[0m: name '_metadata_array' is not defined" + ] + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 165, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'torch_scatter'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minsert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'..'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata_loaders\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_train_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mget_eval_loader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 17\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrouper\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCombinatorialGrouper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_counts\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/dr_benchmark/wilds/common/data_loaders.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msampler\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mWeightedRandomSampler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSubsetRandomSampler\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_counts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msplit_into_groups\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/dr_benchmark/wilds/common/utils.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch_scatter\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mSubset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mpandas\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtypes\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCategoricalDtype\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torch_scatter'" + ] + } + ], + "source": [ + "import os, csv\n", + "import time\n", + "import argparse\n", + "import IPython\n", + "import pandas as pd\n", + "import torch\n", + "import torch.nn as nn\n", + "import torchvision\n", + "import sys\n", + "from collections import defaultdict\n", + "# torch.multiprocessing.set_sharing_strategy('file_system')\n", + "\n", + "# TODO: Replace this once we make wilds into an installed package\n", + "sys.path.insert(1, os.path.join(sys.path[0], '..'))\n", + "\n", + "from wilds.common.data_loaders import get_train_loader, get_eval_loader\n", + "from wilds.common.grouper import CombinatorialGrouper\n", + "from wilds.common.utils import get_counts\n", + "\n", + "from models.model_attributes import model_attributes\n", + "from utils import set_seed, Logger, BatchLogger, log_args, ParseKwargs, load\n", + "from train import train, evaluate\n", + "from data import dataset_attributes\n", + "from optimizer import optimizer_attributes\n", + "from scheduler import scheduler_attributes\n", + "from loss import losses\n", + "from utils import log_group_data\n", + "from algorithms.constructors import algorithm_constructors" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from examples.models.model_attributes import model_attributes" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'utils'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_attributes\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmodel_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mset_seed\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mLogger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mCSVBatchLogger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mParseKwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mload\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 22\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdataset_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0moptimizer_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/dr_benchmark/examples/train.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0msave\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'utils'" + ] + } + ], + "source": [ + "def initialize_algorithm(args, datasets, train_grouper):\n", + " train_dataset = datasets['train']['dataset']\n", + " train_loader = datasets['train']['loader']\n", + "\n", + " # Configure the final layer of the networks used\n", + " # The code below are defaults. Edit this if you need special config for your model.\n", + " if (train_dataset.is_classification) and (train_dataset.y_size == 1):\n", + " # For single-task classification, we have one output per class\n", + " d_out = train_dataset.n_classes\n", + " elif (train_dataset.is_classification) and (train_dataset.y_size > 1) and (train_dataset.n_classes == 2):\n", + " # For multi-task binary classification (each output is the logit for each binary class)\n", + " d_out = train_dataset.y_size\n", + " elif (not train_dataset.is_classification):\n", + " # For regression, we have one output per target dimension\n", + " d_out = train_dataset.y_size\n", + " else:\n", + " raise RuntimeError('d_out not defined.')\n", + " \n", + "\n", + " # Sanity checking input args\n", + " if args.algorithm == 'groupDRO':\n", + " assert args.train_loader_kwargs['uniform_over_groups']\n", + " elif args.algorithm in ['deepCORAL', 'IRM']:\n", + " assert args.train_loader == 'group'\n", + " assert args.train_loader_kwargs['uniform_over_groups']\n", + " assert args.train_loader_kwargs['distinct_groups']\n", + "\n", + " # Other config\n", + " n_train_steps = len(train_loader) * args.n_epochs\n", + "# prediction_fn = dataset_attributes[args.dataset]['prediction_fn']\n", + " loss = losses[args.loss_function]\n", + " metric = dataset_attributes[args.dataset]['metric']\n", + " train_g = train_grouper.metadata_to_group(train_dataset.metadata_array)\n", + " is_group_in_train = get_counts(train_g, train_grouper.n_groups) > 0\n", + " algorithm_constructor = algorithm_constructors[args.algorithm]\n", + " algorithm = algorithm_constructor(\n", + " args=args,\n", + " d_out=d_out,\n", + " grouper=train_grouper,\n", + " loss=loss,\n", + " metric=metric,\n", + " n_train_steps=n_train_steps,\n", + " is_group_in_train=is_group_in_train)\n", + " return algorithm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def main():\n", + " parser = argparse.ArgumentParser()\n", + "\n", + " # Dataset\n", + " parser.add_argument('-d', '--dataset', choices=dataset_attributes.keys(), required=True)\n", + " parser.add_argument('--split_scheme', default='standard',\n", + " help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')\n", + " parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={})\n", + " parser.add_argument('--root_dir', default=None, required=True,\n", + " help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')\n", + " parser.add_argument('--download', default=False, action='store_true',\n", + " help='If true, tries to downloads the dataset if it does not exist in root_dir.')\n", + " parser.add_argument('--frac', type=float, default=1.0,\n", + " help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.')\n", + "\n", + " # Loaders\n", + " parser.add_argument('--train_loader', choices=['standard', 'group'], default='standard')\n", + " parser.add_argument('--train_loader_kwargs', nargs='*', action=ParseKwargs, default={})\n", + " parser.add_argument('--eval_loader', choices=['standard', 'group'], default='standard')\n", + " parser.add_argument('--batch_size', type=int, default=32)\n", + " parser.add_argument('--no_pin_memory', action='store_true') # TODO: put as loader_kwargs\n", + " parser.add_argument('--num_workers', type=int, default=4) # TODO: put as loader kwargs\n", + "\n", + " # Model\n", + " parser.add_argument(\n", + " '--model',\n", + " choices=model_attributes.keys(),\n", + " default='resnet50')\n", + " parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},\n", + " help='keyword arguments for model initialization passed as key1=value1 key2=value2')\n", + " parser.add_argument('--train_from_scratch', action='store_true', default=False)\n", + "\n", + " # Algorithm and objective\n", + " parser.add_argument('--algorithm', required=True, choices=algorithm_constructors.keys())\n", + " parser.add_argument('--algorithm_kwargs', nargs='*', action=ParseKwargs, default={})\n", + " parser.add_argument('--groupby_fields', nargs='+', default=None)\n", + " parser.add_argument('--loss_function', required=True, choices = losses.keys()) #TODO: make default\n", + " parser.add_argument('--val_metric', default=None)\n", + "\n", + " # Optimization\n", + " parser.add_argument('--n_epochs', type=int, default=4)\n", + " parser.add_argument('--optimizer', default=None, choices=optimizer_attributes.keys())\n", + " parser.add_argument('--lr', type=float, required=True)\n", + " parser.add_argument('--weight_decay', type=float, required=True)\n", + " parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})\n", + " parser.add_argument('--scheduler', default=None, choices=scheduler_attributes.keys())\n", + " parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})\n", + " parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')\n", + " parser.add_argument('--scheduler_metric_name')\n", + "\n", + " # Evaluation\n", + " parser.add_argument('--evaluate_all_splits', action='store_true', default=False)\n", + " parser.add_argument('--additional_eval_splits', nargs='+', default=[])\n", + "\n", + " # Misc\n", + " parser.add_argument('--device', type=int, default=0)\n", + " parser.add_argument('--seed', type=int, default=0)\n", + " parser.add_argument('--log_dir', default='./logs')\n", + " parser.add_argument('--log_every', default=50, type=int)\n", + " parser.add_argument('--save_step', type=int, default=None)\n", + " parser.add_argument('--save_best', action='store_true', default=False)\n", + " parser.add_argument('--save_last', action='store_true', default=False)\n", + " parser.add_argument('--save_outputs', action='store_true', default=False)\n", + " parser.add_argument('--no_group_logging', action='store_true', default=False)\n", + " parser.add_argument('--val_metric_decreasing', action='store_true', default=False)\n", + " parser.add_argument('--use_wandb', action='store_true', default=False)\n", + " parser.add_argument('--progress_bar', action='store_true', default=False)\n", + " parser.add_argument('--resume', default=False, action='store_true')\n", + " parser.add_argument('--eval_only', default=False, action='store_true')\n", + "\n", + " args = parser.parse_args()\n", + "\n", + " # set device\n", + " args.device = torch.device(\"cuda:\" + str(args.device)) if torch.cuda.is_available() else torch.device(\"cpu\")\n", + "\n", + " # Set defaults\n", + " if args.groupby_fields is None:\n", + " args.no_group_logging = True\n", + " if args.val_metric is None:\n", + " args.val_metric = dataset_attributes[args.dataset]['val_metric']\n", + "\n", + " ## Initialize logs\n", + " if os.path.exists(args.log_dir) and args.resume:\n", + " resume=True\n", + " mode='a'\n", + " else:\n", + " resume=False\n", + " mode='w'\n", + " if not os.path.exists(args.log_dir):\n", + " os.makedirs(args.log_dir)\n", + " logger = Logger(os.path.join(args.log_dir, 'log.txt'), mode)\n", + "\n", + " # Record args\n", + " log_args(args, logger)\n", + "\n", + " # Set random seed\n", + " set_seed(args.seed)\n", + "\n", + " # Data\n", + " full_dataset = dataset_attributes[args.dataset]['constructor'](\n", + " root_dir=args.root_dir,\n", + " download=args.download,\n", + " split_scheme=args.split_scheme,\n", + " **args.dataset_kwargs)\n", + "\n", + " # To implement data augmentation (i.e., have different transforms\n", + " # at training time vs. test time), modify these two lines:\n", + " train_transform = dataset_attributes[args.dataset]['transform'](args.model)\n", + " if dataset_attributes[args.dataset].get('eval_transform') is None:\n", + " eval_transform = dataset_attributes[args.dataset]['transform'](args.model)\n", + " else:\n", + " eval_transform = dataset_attributes[args.dataset]['eval_transform'](args.model)\n", + "\n", + " train_grouper = CombinatorialGrouper(\n", + " dataset=full_dataset,\n", + " groupby_fields=args.groupby_fields)\n", + "\n", + " datasets = defaultdict(dict)\n", + " for split in full_dataset.split_dict.keys():\n", + " if split=='train':\n", + " transform = train_transform\n", + " verbose = True\n", + " elif split == 'val':\n", + " transform = eval_transform\n", + " verbose = True\n", + " else:\n", + " transform = eval_transform\n", + " verbose = False\n", + " # Get subset\n", + " datasets[split]['dataset'] = full_dataset.get_subset(\n", + " split,\n", + " frac=args.frac,\n", + " transform=transform)\n", + "\n", + " # Get loader\n", + " shared_loader_kwargs = {\n", + " 'num_workers': args.num_workers,\n", + " 'pin_memory': not args.no_pin_memory,\n", + " 'batch_size': args.batch_size,\n", + " 'collate_fn': dataset_attributes[args.dataset]['collate']\n", + " }\n", + "\n", + " if split == 'train':\n", + " datasets[split]['loader'] = get_train_loader(\n", + " loader=args.train_loader,\n", + " dataset=datasets[split]['dataset'],\n", + " grouper=train_grouper,\n", + " train_loader_kwargs=args.train_loader_kwargs,\n", + " **shared_loader_kwargs)\n", + " else:\n", + " datasets[split]['loader'] = get_eval_loader(\n", + " loader=args.eval_loader,\n", + " dataset=datasets[split]['dataset'],\n", + " grouper=train_grouper,\n", + " **shared_loader_kwargs)\n", + "\n", + " # Set fields\n", + " datasets[split]['split'] = split\n", + " datasets[split]['name'] = full_dataset.split_names[split]\n", + " datasets[split]['verbose'] = verbose\n", + " # Loggers\n", + " # Loggers\n", + " datasets[split]['eval_logger'] = BatchLogger(\n", + " os.path.join(args.log_dir, f'{split}_eval.csv'), mode=mode, use_wandb=args.use_wandb)\n", + " datasets[split]['algo_logger'] = BatchLogger(\n", + " os.path.join(args.log_dir, f'{split}_algo.csv'), mode=mode, use_wandb=args.use_wandb)\n", + "\n", + " if args.use_wandb:\n", + " initialize_wandb(args)\n", + "\n", + " # Logging dataset info\n", + " if args.no_group_logging and full_dataset.is_classification and full_dataset.y_size==1:\n", + " log_grouper = CombinatorialGrouper(\n", + " dataset=full_dataset,\n", + " groupby_fields=['y'])\n", + " elif args.no_group_logging:\n", + " log_grouper = None\n", + " else:\n", + " log_grouper = train_grouper\n", + " log_group_data(args, datasets, log_grouper, logger)\n", + "\n", + " ## Initialize algorithm\n", + " algorithm = initialize_algorithm(args, datasets, train_grouper)\n", + "\n", + " if not args.eval_only:\n", + " ## Load saved results if resuming\n", + " resume_success = False\n", + " if resume:\n", + " save_path = os.path.join(args.log_dir, 'last_model.pth')\n", + " if not os.path.exists(save_path):\n", + " epochs = [\n", + " int(file.split('_')[0])\n", + " for file in os.listdir(args.log_dir) if file.endswith('.pth')]\n", + " if len(epochs) > 0:\n", + " latest_epoch = max(epochs)\n", + " save_path = os.path.join(args.log_dir, f'{latest_epoch}_model.pth')\n", + " try:\n", + " prev_epoch, best_val_metric = load(algorithm, save_path)\n", + " epoch_offset = prev_epoch + 1\n", + " logger.write(f'Resuming from epoch {epoch_offset} with best val metric {best_val_metric}')\n", + " resume_success = True\n", + " except FileNotFoundError:\n", + " pass\n", + "\n", + " if resume_success == False:\n", + " epoch_offset=0\n", + " best_val_metric=None\n", + "\n", + "\n", + " train(algorithm,\n", + " datasets,\n", + " logger,\n", + " args,\n", + " epoch_offset=epoch_offset,\n", + " best_val_metric=best_val_metric)\n", + " else:\n", + " best_model_path = os.path.join(args.log_dir, 'best_model.pth')\n", + " best_epoch, best_val_metric = load(algorithm, best_model_path)\n", + " evaluate(algorithm, datasets, best_epoch, logger)\n", + "\n", + " logger.close()\n", + " for split in datasets:\n", + " datasets[split]['eval_logger'].close()\n", + " datasets[split]['algo_logger'].close()\n", + "\n", + "if __name__=='__main__':\n", + " main()\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "language": "python", + "name": "python2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/sandbox_model.ipynb b/sandbox_model.ipynb new file mode 100644 index 00000000..c264d747 --- /dev/null +++ b/sandbox_model.ipynb @@ -0,0 +1,876 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Initialize dataset object" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "57.8772239685\n", + "66.8270189762\n" + ] + } + ], + "source": [ + "import numpy as np, pandas as pd, os, time, torch, torchvision\n", + "data_dir = '/oak/stanford/groups/akundaje/abalsubr/DREAM/wilds/codalab_archive/'\n", + "tf = 'MAX'\n", + "itime = time.time()\n", + "train_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.train.labels.tsv.gz'.format(tf)), sep='\\t')\n", + "print(time.time() - itime)\n", + "val_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.val.labels.tsv.gz'.format(tf)), sep='\\t')\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", + "val_celltype = ['A549']\n", + "test_celltype = ['GM12878']\n", + "all_celltypes = train_celltypes + val_celltype + test_celltype\n", + "\n", + "metadata_map = {}\n", + "metadata_map['chr'] = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']\n", + "metadata_map['celltype'] = all_celltypes\n", + "\n", + "_split_dict = {\n", + " 'train': 0,\n", + " 'val-id': 1,\n", + " 'test': 2,\n", + " 'val-ood': 3\n", + "}\n", + "_split_names = {\n", + " 'train': 'Train',\n", + " 'val-id': 'Validation (ID)',\n", + " 'test': 'Test',\n", + " 'val-ood': 'Validation (OOD)'\n", + "}\n", + "_split_scheme = 'standard'" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "('H1-hESC', 25.299736976623535)\n", + "('HCT116', 49.68733310699463)\n", + "('HeLa-S3', 74.65905213356018)\n", + "('HepG2', 99.33112812042236)\n", + "('K562', 124.1327919960022)\n", + "('A549', 149.19999814033508)\n", + "('GM12878', 174.0277030467987)\n" + ] + } + ], + "source": [ + "itime = time.time()\n", + "sequence_filename = os.path.join(data_dir, 'sequence.npz')\n", + "seq_arr = np.load(sequence_filename)\n", + "print(time.time() - itime)\n", + "\n", + "itime = time.time()\n", + "_seq_bp = {}\n", + "for chrom in seq_arr:\n", + " _seq_bp[chrom] = seq_arr[chrom]\n", + " print(chrom, time.time() - itime)\n", + "itime = time.time()\n", + "_dnase_allcelltypes = {}\n", + "for ct in all_celltypes:\n", + " dnase_filename = os.path.join(data_dir, '{}_dnase.npz'.format(ct))\n", + " dnase_npz_file = np.load(dnase_filename)\n", + " _dnase_allcelltypes[ct] = {}\n", + " for chrom in _seq_bp:\n", + " _dnase_allcelltypes[ct][chrom] = dnase_npz_file[chrom]\n", + " print(ct, time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'A549': {'chr1': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr10': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr11': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr12': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr13': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr14': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr15': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr16': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr17': array([ 0.35986328, 0.35986328, 0.35986328, ..., 0. ,\n", + " 0. , 0. ], dtype=float16),\n", + " 'chr18': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr19': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr2': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr20': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr21': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr22': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr3': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr4': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr5': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr6': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr7': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr8': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr9': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chrX': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16)},\n", + " 'GM12878': {'chr1': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr10': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr11': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr12': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr13': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr14': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr15': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr16': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr17': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr18': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr19': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr2': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr20': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr21': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr22': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr3': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr4': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr5': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr6': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr7': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr8': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr9': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chrX': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16)},\n", + " 'H1-hESC': {'chr1': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr10': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr11': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr12': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr13': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr14': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr15': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr16': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr17': array([ 0.71972656, 0.71972656, 0.71972656, ..., 0. ,\n", + " 0. , 0. ], dtype=float16),\n", + " 'chr18': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr19': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr2': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr20': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr21': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr22': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr3': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr4': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr5': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr6': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr7': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr8': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr9': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chrX': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16)},\n", + " 'HCT116': {'chr1': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr10': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr11': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr12': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr13': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr14': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr15': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr16': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr17': array([ 0.80419922, 0.80419922, 0.80419922, ..., 0. ,\n", + " 0. , 0. ], dtype=float16),\n", + " 'chr18': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr19': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr2': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr20': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr21': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr22': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr3': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr4': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr5': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr6': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr7': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr8': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr9': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chrX': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16)},\n", + " 'HeLa-S3': {'chr1': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr10': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr11': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr12': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr13': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr14': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr15': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr16': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr17': array([ 0.71972656, 0.71972656, 0.71972656, ..., 0. ,\n", + " 0. , 0. ], dtype=float16),\n", + " 'chr18': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr19': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr2': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr20': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr21': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr22': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr3': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr4': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr5': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr6': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr7': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr8': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr9': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chrX': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16)},\n", + " 'HepG2': {'chr1': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr10': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr11': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr12': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr13': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr14': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr15': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr16': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr17': array([ 0.71972656, 0.71972656, 0.71972656, ..., 0. ,\n", + " 0. , 0. ], dtype=float16),\n", + " 'chr18': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr19': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr2': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr20': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr21': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr22': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr3': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr4': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr5': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr6': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr7': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr8': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr9': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chrX': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16)},\n", + " 'K562': {'chr1': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr10': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr11': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr12': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr13': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr14': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr15': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr16': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr17': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr18': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr19': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr2': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr20': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr21': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr22': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr3': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr4': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr5': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr6': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr7': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr8': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr9': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chrX': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16)}}" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "_dnase_allcelltypes" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "tr_chrs = ['chr2', 'chr9', 'chr11']\n", + "te_chrs = ['chr1', 'chr8', 'chr21']\n", + "training_df = train_chr[np.isin(train_chr['chr'], tr_chrs)]\n", + "val_df = val_chr[np.isin(val_chr['chr'], te_chrs)]\n", + "all_df = pd.concat([training_df, val_df])\n", + "\n", + "#filter_msk = all_df['start'] >= 0\n", + "filter_msk = all_df['start']%1000 == 0\n", + "all_df = all_df[filter_msk]" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/users/abalsubr/anaconda2/envs/scs3/lib/python3.6/site-packages/ipykernel_launcher.py:6: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy\n", + " \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1.659163236618042\n" + ] + } + ], + "source": [ + "itime = time.time()\n", + "pd_list = []\n", + "for ct in all_celltypes:\n", + " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", + " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", + " tc_chr['celltype'] = ct\n", + " pd_list.append(tc_chr)\n", + "metadata_df = pd.concat(pd_list)\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3.0391879081726074\n" + ] + } + ], + "source": [ + "itime = time.time()\n", + "y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", + "non_ambig_mask = (y_array != -1)\n", + "metadata_df['y'] = y_array\n", + "_metadata_df = metadata_df[non_ambig_mask]\n", + "_y_array = torch.LongTensor(y_array[non_ambig_mask])\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "12.390011310577393\n" + ] + } + ], + "source": [ + "itime = time.time()\n", + "chr_ints = _metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['chr'])] )).values\n", + "celltype_ints = _metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['celltype'])] )).values\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/users/abalsubr/anaconda2/envs/scs3/lib/python3.6/site-packages/ipykernel_launcher.py:12: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy\n", + " if sys.path[0] == '':\n" + ] + } + ], + "source": [ + "train_chr_mask = np.isin(_metadata_df['chr'], tr_chrs)\n", + "val_chr_mask = np.isin(_metadata_df['chr'], te_chrs)\n", + "train_celltype_mask = np.isin(_metadata_df['celltype'], train_celltypes)\n", + "val_celltype_mask = np.isin(_metadata_df['celltype'], val_celltype)\n", + "test_celltype_mask = np.isin(_metadata_df['celltype'], test_celltype)\n", + "\n", + "split_array = -1*np.ones(_metadata_df.shape[0]).astype(int)\n", + "split_array[np.logical_and(train_chr_mask, train_celltype_mask)] = _split_dict['train']\n", + "split_array[np.logical_and(val_chr_mask, test_celltype_mask)] = _split_dict['test']\n", + "split_array[np.logical_and(val_chr_mask, val_celltype_mask)] = _split_dict['val-ood']\n", + "split_array[np.logical_and(val_chr_mask, train_celltype_mask)] = _split_dict['val-id']\n", + "_metadata_df['split'] = split_array\n", + "_split_array = split_array" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# get_input (idx)" + ] + }, + { + "cell_type": "code", + "execution_count": 153, + "metadata": {}, + "outputs": [], + "source": [ + "idx = 3\n", + "this_metadata = _metadata_df.iloc[idx, :]\n", + "\n", + "itime = time.time()\n", + "flank_size = 400\n", + "interval_start = this_metadata['start'] - flank_size\n", + "interval_end = this_metadata['stop'] + flank_size\n", + "dnase_this = _dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end]\n", + "seq_this = _seq_bp[this_metadata['chr']][interval_start:interval_end]\n", + "data = np.column_stack([seq_this, dnase_this])\n", + "# print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 154, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "4600" + ] + }, + "execution_count": 154, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data.shape\n", + "interval_end\n", + "# itime = time.time()\n", + "# np.save(os.path.join(data_dir, 'stmp.npy'), sa)\n", + "# print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mitime\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m metadata_array = torch.stack(\n\u001b[0;32m----> 3\u001b[0;31m (torch.LongTensor(metadata_df['chr'].values), \n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLongTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetadata_df\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'celltype'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m self._y_array),\n", + "\u001b[0;31mTypeError\u001b[0m: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool." + ] + } + ], + "source": [ + "itime = time.time()\n", + "metadata_array = torch.stack(\n", + " (torch.LongTensor(chr_ints), \n", + " torch.LongTensor(celltype_ints), \n", + " _y_array),\n", + " dim=1)\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "python3 examples/run_expt.py -d camelyon17 --algorithm ERM --model densenet121 --split_scheme standard --groupby_fields hospital --loss_function cross_entropy --optimizer SGD \n", + "--lr 0.0001 --batch_size 32 --weight_decay 0 --n_epochs 10 --scheduler ReduceLROnPlateau --scheduler_metric_split val --scheduler_metric_name acc_avg \n", + "--log_dir log --log_every 50 --save_step 1000 --save_best --save_last --seed 0 --evaluate_all_splits --root_dir ROOTDIR" + ] + }, + { + "cell_type": "code", + "execution_count": 156, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name '_metadata_array' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0m_metadata_array\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mNameError\u001b[0m: name '_metadata_array' is not defined" + ] + } + ], + "source": [ + "_metadata_array" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from examples.models.model_attributes import model_attributes" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'utils'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_attributes\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmodel_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mset_seed\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mLogger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mCSVBatchLogger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mParseKwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mload\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 22\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdataset_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0moptimizer_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/dr_benchmark/examples/train.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0msave\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'utils'" + ] + } + ], + "source": [ + "import os, csv\n", + "import time\n", + "import argparse\n", + "import IPython\n", + "import pandas as pd\n", + "import torch\n", + "import torch.nn as nn\n", + "import torchvision\n", + "import sys\n", + "from collections import defaultdict\n", + "\n", + "# TODO: Replace this once we make wilds into an installed package\n", + "sys.path.insert(1, os.path.join(sys.path[0], '..'))\n", + "\n", + "from wilds.common.data_loaders import get_train_loader, get_eval_loader\n", + "from wilds.common.grouper import CombinatorialGrouper\n", + "from wilds.common.utils import get_counts\n", + "\n", + "from examples.models.model_attributes import model_attributes\n", + "from examples.utils import set_seed, Logger, CSVBatchLogger, log_args, ParseKwargs, load\n", + "from examples.train import train\n", + "from examples.data import dataset_attributes\n", + "from examples.optimizer import optimizer_attributes\n", + "from examples.scheduler import scheduler_attributes\n", + "from examples.loss import losses\n", + "from examples.utils import log_group_data\n", + "from examples.algorithms.constructors import algorithm_constructors\n", + "\n", + "\n", + "def initialize_algorithm(args, datasets, train_grouper):\n", + " train_dataset = datasets['train']['dataset']\n", + " train_loader = datasets['train']['loader']\n", + "\n", + " # Configure the final layer of the networks used\n", + " # The code below are defaults. Edit this if you need special config for your model.\n", + " if (train_dataset.is_classification) and (train_dataset.y_size == 1):\n", + " # For single-task classification, we have one output per class\n", + " d_out = train_dataset.n_classes\n", + " elif (not train_dataset.is_classification):\n", + " # For regression, we have one output per target dimension\n", + " d_out = train_dataset.y_size\n", + " else:\n", + " # TODO: Handle dataset-specific multi-task stuff here, e.g., for OGB\n", + " pass\n", + "\n", + " # Sanity checking input args\n", + " if args.algorithm == 'groupDRO':\n", + " assert args.train_loader_kwargs['uniform_over_groups']\n", + " elif args.algorithm in ['deepCORAL', 'IRM']:\n", + " assert args.train_loader == 'group'\n", + " assert args.train_loader_kwargs['uniform_over_groups']\n", + " assert args.train_loader_kwargs['distinct_groups']\n", + "\n", + " # Other config\n", + " n_train_steps = len(train_loader) * args.n_epochs\n", + " prediction_fn = dataset_attributes[args.dataset]['prediction_fn']\n", + " loss = losses[args.loss_function]\n", + " metric_constructor = dataset_attributes[args.dataset]['metric']\n", + " train_g = train_grouper.metadata_to_group(train_dataset.metadata_array)\n", + " is_group_in_train = get_counts(train_g, train_grouper.n_groups) > 0\n", + " algorithm_constructor = algorithm_constructors[args.algorithm]\n", + " algorithm = algorithm_constructor(\n", + " args=args,\n", + " d_out=d_out,\n", + " grouper=train_grouper,\n", + " prediction_fn=prediction_fn,\n", + " loss=loss,\n", + " metric_constructor=metric_constructor,\n", + " n_train_steps=n_train_steps,\n", + " is_group_in_train=is_group_in_train)\n", + " return algorithm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "parser = argparse.ArgumentParser()\n", + "\n", + "# Dataset\n", + "parser.add_argument('-d', '--dataset', choices=dataset_attributes.keys(), required=True)\n", + "parser.add_argument('--split_scheme', default='standard',\n", + " help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')\n", + "parser.add_argument('--root_dir', default=None, required=True,\n", + " help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')\n", + "parser.add_argument('--download', default=False, action='store_true',\n", + " help='If true, tries to downloads the dataset if it does not exist in root_dir.')\n", + "parser.add_argument('--frac', type=float, default=1.0,\n", + " help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.')\n", + "\n", + "# Loaders\n", + "parser.add_argument('--train_loader', choices=['standard', 'group'], default='standard')\n", + "parser.add_argument('--train_loader_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--eval_loader', choices=['standard', 'group'], default='standard')\n", + "parser.add_argument('--batch_size', type=int, default=32)\n", + "\n", + "# Model\n", + "parser.add_argument(\n", + " '--model',\n", + " choices=model_attributes.keys(),\n", + " default='resnet50')\n", + "parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},\n", + " help='keyword arguments for model initialization passed as key1=value1 key2=value2')\n", + "parser.add_argument('--train_from_scratch', action='store_true', default=False)\n", + "\n", + "# Algorithm and objective\n", + "parser.add_argument('--algorithm', required=True, choices=algorithm_constructors.keys())\n", + "parser.add_argument('--algorithm_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--groupby_fields', nargs='+', default=None)\n", + "parser.add_argument('--loss_function', required=True, choices = losses.keys()) #TODO: make default\n", + "parser.add_argument('--val_metric', default=None)\n", + "\n", + "# Optimization\n", + "parser.add_argument('--n_epochs', type=int, default=4)\n", + "parser.add_argument('--optimizer', default=None, choices=optimizer_attributes.keys())\n", + "parser.add_argument('--lr', type=float, required=True)\n", + "parser.add_argument('--weight_decay', type=float, required=True)\n", + "parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--scheduler', default=None, choices=scheduler_attributes.keys())\n", + "parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')\n", + "parser.add_argument('--scheduler_metric_name')\n", + "\n", + "# Evaluation\n", + "parser.add_argument('--evaluate_all_splits', action='store_true', default=False)\n", + "parser.add_argument('--additional_eval_splits', nargs='+', default=[])\n", + "\n", + "# Misc\n", + "parser.add_argument('--device', default='cuda')\n", + "parser.add_argument('--seed', type=int, default=0)\n", + "parser.add_argument('--log_dir', default='./logs')\n", + "parser.add_argument('--log_every', default=50, type=int)\n", + "parser.add_argument('--save_step', type=int, default=None)\n", + "parser.add_argument('--save_best', action='store_true', default=False)\n", + "parser.add_argument('--save_last', action='store_true', default=False)\n", + "parser.add_argument('--save_outputs', action='store_true', default=False)\n", + "parser.add_argument('--no_group_logging', action='store_true', default=False)\n", + "\n", + "parser.add_argument('--resume', default=False, action='store_true')\n", + "\n", + "args = parser.parse_args()\n", + "\n", + "# Set defaults\n", + "if args.groupby_fields is None:\n", + " args.no_group_logging = True\n", + "if args.val_metric is None:\n", + " args.val_metric = dataset_attributes[args.dataset]['val_metric']\n", + "\n", + "## Initialize logs\n", + "if os.path.exists(args.log_dir) and args.resume:\n", + " resume=True\n", + " mode='a'\n", + "else:\n", + " resume=False\n", + " mode='w'\n", + "if not os.path.exists(args.log_dir):\n", + " os.makedirs(args.log_dir)\n", + "logger = Logger(os.path.join(args.log_dir, 'log.txt'), mode)\n", + "\n", + "# Record args\n", + "log_args(args, logger)\n", + "\n", + "# Set random seed\n", + "set_seed(args.seed)\n", + "\n", + "# Data\n", + "full_dataset = dataset_attributes[args.dataset]['constructor'](\n", + " root_dir=args.root_dir,\n", + " download=args.download,\n", + " split_scheme=args.split_scheme)\n", + "\n", + "# To implement data augmentation (i.e., have different transforms\n", + "# at training time vs. test time), modify these two lines:\n", + "train_transform = dataset_attributes[args.dataset]['transform'](args.model)\n", + "eval_transform = dataset_attributes[args.dataset]['transform'](args.model)\n", + "\n", + "train_grouper = CombinatorialGrouper(\n", + " dataset=full_dataset,\n", + " groupby_fields=args.groupby_fields)\n", + "\n", + "datasets = defaultdict(dict)\n", + "for split in full_dataset.split_dict.keys():\n", + " if split=='train':\n", + " transform = train_transform\n", + " verbose = True\n", + " elif split == 'val':\n", + " transform = eval_transform\n", + " verbose = True\n", + " else:\n", + " transform = eval_transform\n", + " verbose = False\n", + " # Get subset\n", + " datasets[split]['dataset'] = full_dataset.get_subset(\n", + " split,\n", + " frac=args.frac,\n", + " transform=transform)\n", + "\n", + " # Get loader\n", + " shared_loader_kwargs = {\n", + " 'num_workers': 4,\n", + " 'pin_memory': True,\n", + " 'batch_size': args.batch_size,\n", + " 'collate_fn': dataset_attributes[args.dataset]['collate']\n", + " }\n", + "\n", + " if split == 'train':\n", + " datasets[split]['loader'] = get_train_loader(\n", + " loader=args.train_loader,\n", + " dataset=datasets[split]['dataset'],\n", + " grouper=train_grouper,\n", + " train_loader_kwargs=args.train_loader_kwargs,\n", + " **shared_loader_kwargs)\n", + " else:\n", + " datasets[split]['loader'] = get_eval_loader(\n", + " loader=args.eval_loader,\n", + " dataset=datasets[split]['dataset'],\n", + " grouper=train_grouper,\n", + " **shared_loader_kwargs)\n", + "\n", + " # Set fields\n", + " datasets[split]['split'] = split\n", + " datasets[split]['name'] = full_dataset.split_names[split]\n", + " datasets[split]['verbose'] = verbose\n", + " # Loggers\n", + " datasets[split]['eval_logger'] = CSVBatchLogger(\n", + " os.path.join(args.log_dir, f'{split}_eval.csv'), mode=mode)\n", + " datasets[split]['algo_logger'] = CSVBatchLogger(\n", + " os.path.join(args.log_dir, f'{split}_algo.csv'), mode=mode)\n", + "\n", + "# Logging dataset info\n", + "if args.no_group_logging and full_dataset.is_classification and full_dataset.y_size==1:\n", + " log_grouper = CombinatorialGrouper(\n", + " dataset=full_dataset,\n", + " groupby_fields=['y'])\n", + "elif args.no_group_logging:\n", + " log_grouper = None\n", + "else:\n", + " log_grouper = train_grouper\n", + "log_group_data(args, datasets, log_grouper, logger)\n", + "\n", + "## Initialize algorithm\n", + "algorithm = initialize_algorithm(args, datasets, train_grouper)\n", + "\n", + "## Load saved results if resuming\n", + "if resume:\n", + " save_path = os.path.join(args.log_dir, 'last_model.pth')\n", + " prev_epoch, best_val_metric = load(algorithm, save_path)\n", + " epoch_offset = prev_epoch + 1\n", + "else:\n", + " epoch_offset=0\n", + " best_val_metric=None\n", + "\n", + "train(algorithm,\n", + " datasets,\n", + " logger,\n", + " args,\n", + " epoch_offset=epoch_offset,\n", + " best_val_metric=best_val_metric)\n", + "\n", + "logger.close()\n", + "for split in datasets:\n", + " datasets[split]['eval_logger'].close()\n", + " datasets[split]['algo_logger'].close()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "language": "python", + "name": "python2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 4a733f331f220c7e5a72aada349800e1af9a1823 Mon Sep 17 00:00:00 2001 From: aikanor Date: Tue, 2 Feb 2021 11:38:01 -0800 Subject: [PATCH 003/244] model tweak --- examples/models/CNN_genome.py | 151 +--------------------------------- sandbox_model.ipynb | 32 +++++++ 2 files changed, 34 insertions(+), 149 deletions(-) diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py index f0115322..75295cd3 100644 --- a/examples/models/CNN_genome.py +++ b/examples/models/CNN_genome.py @@ -16,7 +16,7 @@ class Beagle(nn.Module): TODO: Finish docstring. """ - def __init__(self, args): + def __init__(self): """ Parameters ---------- @@ -25,7 +25,7 @@ def __init__(self, args): """ super(Beagle, self).__init__() - self.dropout = args.dropout + self.dropout = 0.3 self.num_cell_types = 1 self.conv1 = nn.Conv2d(5, 300, (19, 1), stride = (1, 1), padding=(9,0)) self.conv2 = nn.Conv2d(300, 200, (11, 1), stride = (1, 1), padding = (5,0)) @@ -60,150 +60,3 @@ def forward(self, s): s = self.fc3(s) return s, conv_out - - - -#class MLP(nn.Module): -# """Just an MLP""" -# def __init__(self, n_inputs, n_outputs, width, depth, drop_out): -# super(MLP, self).__init__() -# -# self.input = nn.Linear(n_inputs, width) -# self.dropout = nn.Dropout(dropout) -# self.hiddens = nn.ModuleList([ -# nn.Linear(width,width) -# for _ in range(depth-2)]) -# self.output = nn.Linear(width, n_outputs) -# self.n_outputs = n_outputs -# -# def forward(self, x): -# x = self.input(x) -# x = self.dropout(x) -# x = F.relu(x) -# for hidden in self.hiddens: -# x = hidden(x) -# x = self.dropout(x) -# x = F.relu(x) -# x = self.output(x) -# return x - - -""" -DeepSEA architecture (Zhou & Troyanskaya, 2015). -Based on https://github.com/FunctionLab/selene/blob/master/models/deepsea.py -""" - -class DeepSEA(nn.Module): - def __init__(self, sequence_length, n_genomic_features): - """ - Parameters - ---------- - sequence_length : int - n_genomic_features : int - """ - super(DeepSEA, self).__init__() - conv_kernel_size = 8 - pool_kernel_size = 4 - - self.conv_net = nn.Sequential( - nn.Conv1d(4, 320, kernel_size=conv_kernel_size), - nn.ReLU(inplace=True), - nn.MaxPool1d( - kernel_size=pool_kernel_size, stride=pool_kernel_size), - nn.Dropout(p=0.2), - - nn.Conv1d(320, 480, kernel_size=conv_kernel_size), - nn.ReLU(inplace=True), - nn.MaxPool1d( - kernel_size=pool_kernel_size, stride=pool_kernel_size), - nn.Dropout(p=0.2), - - nn.Conv1d(480, 960, kernel_size=conv_kernel_size), - nn.ReLU(inplace=True), - nn.Dropout(p=0.5)) - - reduce_by = conv_kernel_size - 1 - pool_kernel_size = float(pool_kernel_size) - self.n_channels = int( - np.floor( - (np.floor( - (sequence_length - reduce_by) / pool_kernel_size) - - reduce_by) / pool_kernel_size) - - reduce_by) - self.classifier = nn.Sequential( - nn.Linear(960 * self.n_channels, n_genomic_features), - nn.ReLU(inplace=True), - nn.Linear(n_genomic_features, n_genomic_features), - nn.Sigmoid()) - - def forward(self, x): - """Forward propagation of a batch. - """ - out = self.conv_net(x) - reshape_out = out.view(out.size(0), 960 * self.n_channels) - predict = self.classifier(reshape_out) - return predict - -""" -def criterion(): - return nn.BCELoss() - -def get_optimizer(lr): - # The optimizer and the parameters with which to initialize the optimizer. At a later time, we initialize the optimizer by also passing in the model parameters (`model.parameters()`). We cannot initialize the optimizer until the model has been initialized. - return (torch.optim.SGD, {"lr": lr, "weight_decay": 1e-6, "momentum": 0.9}) -""" - - - -""" -DanQ architecture (Quang & Xie, 2016). -""" - -class DanQ(nn.Module): - def __init__(self, sequence_length, n_genomic_features): - """ - Parameters - ---------- - sequence_length : int - Input sequence length - n_genomic_features : int - Total number of features to predict - """ - super(DanQ, self).__init__() - self.nnet = nn.Sequential( - nn.Conv1d(4, 320, kernel_size=26), - nn.ReLU(inplace=True), - nn.MaxPool1d( - kernel_size=13, stride=13), - nn.Dropout(0.2)) - - self.bdlstm = nn.Sequential(nn.LSTM(320, 320, num_layers=1, batch_first=True, bidirectional=True)) - - self._n_channels = math.floor( - (sequence_length - 25) / 13) - self.classifier = nn.Sequential( - nn.Dropout(0.5), - nn.Linear(self._n_channels * 640, 925), - nn.ReLU(inplace=True), - nn.Linear(925, n_genomic_features), - nn.Sigmoid()) - - def forward(self, x): - """Forward propagation of a batch. - """ - out = self.nnet(x) - reshape_out = out.transpose(0, 1).transpose(0, 2) - out, _ = self.bdlstm(reshape_out) - out = out.transpose(0, 1) - reshape_out = out.contiguous().view( - out.size(0), 640 * self._n_channels) - predict = self.classifier(reshape_out) - return predict - -""" -def criterion(): - return nn.BCELoss() - -def get_optimizer(lr): - return (torch.optim.RMSprop, {"lr": lr}) -""" \ No newline at end of file diff --git a/sandbox_model.ipynb b/sandbox_model.ipynb index c264d747..885c8a59 100644 --- a/sandbox_model.ipynb +++ b/sandbox_model.ipynb @@ -288,6 +288,38 @@ "_dnase_allcelltypes" ] }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "from examples.models import CNN_genome" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "unbound method parameters() must be called with Beagle instance as first argument (got nothing instead)", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# def count_parameters(model):\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;31m# return sum(p.numel() for p in model.parameters() if p.requires_grad)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mCNN_genome\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mBeagle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m: unbound method parameters() must be called with Beagle instance as first argument (got nothing instead)" + ] + } + ], + "source": [ + "# def count_parameters(model):\n", + "# return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", + "CNN_genome.Beagle.parameters()" + ] + }, { "cell_type": "code", "execution_count": 14, From 99c7fe65b8bc36d826bba6bf2c231808e413fd71 Mon Sep 17 00:00:00 2001 From: aikanor Date: Sun, 7 Feb 2021 16:24:31 -0800 Subject: [PATCH 004/244] preprocessing changes --- .../sandbox_data-checkpoint.ipynb | 952 ------------------ dataset_preprocessing/encode-tfbs/README.md | 18 + .../encode-tfbs/prep_accessibility.py | 41 + .../encode-tfbs/prep_datasets.ipynb | 279 +++++ .../encode-tfbs/prep_sequence.py | 151 +++ .../sandbox_data.ipynb => sandbox_data.ipynb | 77 +- sandbox_model.ipynb | 486 +++++---- wilds/datasets/camelyon17_dataset.py | 282 +++--- 8 files changed, 953 insertions(+), 1333 deletions(-) delete mode 100644 dataset_preprocessing/encode-tfbs/.ipynb_checkpoints/sandbox_data-checkpoint.ipynb create mode 100644 dataset_preprocessing/encode-tfbs/README.md create mode 100644 dataset_preprocessing/encode-tfbs/prep_accessibility.py create mode 100644 dataset_preprocessing/encode-tfbs/prep_datasets.ipynb create mode 100644 dataset_preprocessing/encode-tfbs/prep_sequence.py rename dataset_preprocessing/encode-tfbs/sandbox_data.ipynb => sandbox_data.ipynb (97%) diff --git a/dataset_preprocessing/encode-tfbs/.ipynb_checkpoints/sandbox_data-checkpoint.ipynb b/dataset_preprocessing/encode-tfbs/.ipynb_checkpoints/sandbox_data-checkpoint.ipynb deleted file mode 100644 index b2e74829..00000000 --- a/dataset_preprocessing/encode-tfbs/.ipynb_checkpoints/sandbox_data-checkpoint.ipynb +++ /dev/null @@ -1,952 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Initialize dataset object" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "57.5289368629\n", - "65.2459537983\n" - ] - } - ], - "source": [ - "import numpy as np, pandas as pd, os, time\n", - "import torch, torchvision\n", - "\n", - "data_dir = '/oak/stanford/groups/akundaje/abalsubr/DREAM/wilds/codalab_archive/'\n", - "tf = 'MAX'\n", - "itime = time.time()\n", - "train_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.train.labels.tsv.gz'.format(tf)), sep='\\t')\n", - "print(time.time() - itime)\n", - "val_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.val.labels.tsv.gz'.format(tf)), sep='\\t')\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", - "val_celltype = ['A549']\n", - "test_celltype = ['GM12878']\n", - "all_celltypes = train_celltypes + val_celltype + test_celltype\n", - "\n", - "metadata_map = {}\n", - "metadata_map['chr'] = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']\n", - "metadata_map['celltype'] = all_celltypes\n", - "\n", - "_split_dict = {\n", - " 'train': 0,\n", - " 'val-id': 1,\n", - " 'test': 2,\n", - " 'val-ood': 3\n", - "}\n", - "_split_names = {\n", - " 'train': 'Train',\n", - " 'val-id': 'Validation (ID)',\n", - " 'test': 'Test',\n", - " 'val-ood': 'Validation (OOD)',\n", - "}\n", - "_split_scheme = 'standard'" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.0467748641968\n", - "('chr1', 4.52302885055542)\n", - "('chr2', 8.645489931106567)\n", - "('chr3', 11.959153890609741)\n", - "('chr4', 15.15813684463501)\n", - "('chr5', 18.22238802909851)\n", - "('chr6', 21.19420099258423)\n", - "('chr7', 23.940655946731567)\n", - "('chr8', 26.415233850479126)\n", - "('chr9', 28.833614826202393)\n", - "('chr10', 31.08920383453369)\n", - "('chr11', 33.37020301818848)\n", - "('chr12', 35.98973989486694)\n", - "('chr13', 37.88540601730347)\n", - "('chr14', 39.68082284927368)\n", - "('chr15', 41.242313861846924)\n", - "('chr16', 42.74874496459961)\n", - "('chr17', 44.12280797958374)\n", - "('chr18', 45.46893382072449)\n", - "('chr19', 46.50577902793884)\n", - "('chr20', 47.59563183784485)\n", - "('chr21', 48.31779384613037)\n", - "('chr22', 49.17265295982361)\n", - "('chrX', 51.75806999206543)\n", - "('H1-hESC', 25.880441904067993)\n", - "('HCT116', 50.130937814712524)\n", - "('HeLa-S3', 75.29559993743896)\n", - "('HepG2', 102.25979495048523)\n", - "('K562', 128.43050694465637)\n", - "('A549', 154.80679488182068)\n", - "('GM12878', 182.0279529094696)\n" - ] - } - ], - "source": [ - "itime = time.time()\n", - "sequence_filename = os.path.join(data_dir, 'sequence.npz')\n", - "seq_arr = np.load(sequence_filename)\n", - "print(time.time() - itime)\n", - "\n", - "itime = time.time()\n", - "_seq_bp = {}\n", - "for chrom in seq_arr:\n", - " _seq_bp[chrom] = seq_arr[chrom]\n", - " print(chrom, time.time() - itime)\n", - "\n", - "itime = time.time()\n", - "_dnase_allcelltypes = {}\n", - "for ct in all_celltypes:\n", - " dnase_filename = os.path.join(data_dir, '{}_dnase.npz'.format(ct))\n", - " dnase_npz_file = np.load(dnase_filename)\n", - " _dnase_allcelltypes[ct] = {}\n", - " for chrom in _seq_bp:\n", - " _dnase_allcelltypes[ct][chrom] = dnase_npz_file[chrom]\n", - " print(ct, time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'all_df' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# len(_dnase_allcelltypes)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mall_df\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mNameError\u001b[0m: name 'all_df' is not defined" - ] - } - ], - "source": [ - "# len(_dnase_allcelltypes)\n", - "all_df" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "ename": "AttributeError", - "evalue": "'module' object has no attribute 'isin'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mtr_chrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr2'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr9'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr11'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mte_chrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr1'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr8'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr21'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mtraining_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtr_chrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mval_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mval_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mte_chrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mall_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtraining_df\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_df\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mAttributeError\u001b[0m: 'module' object has no attribute 'isin'" - ] - } - ], - "source": [ - "tr_chrs = ['chr2', 'chr9', 'chr11']\n", - "te_chrs = ['chr1', 'chr8', 'chr21']\n", - "training_df = train_chr[np.isin(train_chr['chr'], tr_chrs)]\n", - "val_df = val_chr[np.isin(val_chr['chr'], te_chrs)]\n", - "all_df = pd.concat([training_df, val_df])\n", - "\n", - "#filter_msk = all_df['start'] >= 0\n", - "filter_msk = all_df['start']%1000 == 0\n", - "all_df = all_df[filter_msk]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "itime = time.time()\n", - "pd_list = []\n", - "for ct in all_celltypes:\n", - " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", - " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", - " tc_chr['celltype'] = ct\n", - " pd_list.append(tc_chr)\n", - "metadata_df = pd.concat(pd_list)\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "itime = time.time()\n", - "y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", - "non_ambig_mask = (y_array != -1)\n", - "metadata_df['y'] = y_array\n", - "_metadata_df = metadata_df[non_ambig_mask]\n", - "_y_array = torch.LongTensor(y_array[non_ambig_mask])\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "itime = time.time()\n", - "chr_ints = _metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['chr'])] )).values\n", - "celltype_ints = _metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['celltype'])] )).values\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "train_chr_mask = np.isin(_metadata_df['chr'], tr_chrs)\n", - "val_chr_mask = np.isin(_metadata_df['chr'], te_chrs)\n", - "train_celltype_mask = np.isin(_metadata_df['celltype'], train_celltypes)\n", - "val_celltype_mask = np.isin(_metadata_df['celltype'], val_celltype)\n", - "test_celltype_mask = np.isin(_metadata_df['celltype'], test_celltype)\n", - "\n", - "split_array = -1*np.ones(_metadata_df.shape[0]).astype(int)\n", - "split_array[np.logical_and(train_chr_mask, train_celltype_mask)] = _split_dict['train']\n", - "split_array[np.logical_and(val_chr_mask, test_celltype_mask)] = _split_dict['test']\n", - "split_array[np.logical_and(val_chr_mask, val_celltype_mask)] = _split_dict['val-ood']\n", - "split_array[np.logical_and(val_chr_mask, train_celltype_mask)] = _split_dict['val-id']\n", - "_metadata_df['split'] = split_array\n", - "_split_array = split_array" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "ename": "ImportError", - "evalue": "No module named data", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdataset_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mImportError\u001b[0m: No module named data" - ] - } - ], - "source": [ - "from torch.utils.data import DataLoader\n", - "from data import dataset_attributes" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "from PIL import Image\n", - "import argparse\n", - "class ParseKwargs(argparse.Action):\n", - " def __call__(self, parser, namespace, values, option_string=None):\n", - " setattr(namespace, self.dest, dict())\n", - " for value in values:\n", - " key, value_str = value.split('=')\n", - " if value_str.replace('-','').isnumeric():\n", - " processed_val = int(value_str)\n", - " elif value_str.replace('-','').replace('.','').isnumeric():\n", - " processed_val = float(value_str)\n", - " elif value_str in ['True', 'true']:\n", - " processed_val = True\n", - " elif value_str in ['False', 'false']:\n", - " processed_val = False\n", - " else:\n", - " processed_val = value_str\n", - " getattr(namespace, self.dest)[key] = processed_val" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'algorithm_constructors' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0;31m# Algorithm and objective\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 33\u001b[0;31m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'--algorithm'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrequired\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mchoices\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0malgorithm_constructors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 34\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'--algorithm_kwargs'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'*'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maction\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mParseKwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'--groupby_fields'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'+'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mNameError\u001b[0m: name 'algorithm_constructors' is not defined" - ] - } - ], - "source": [ - "ROOTDIR = '/oak/stanford/groups/akundaje/abalsubr/wilds_other'\n", - "args_kw = \"-d camelyon17 --algorithm ERM --model densenet121 --split_scheme standard --groupby_fields hospital --loss_function cross_entropy --optimizer SGD --lr 0.0001 --batch_size 32 --weight_decay 0 --n_epochs 10 --scheduler ReduceLROnPlateau --scheduler_metric_split val --scheduler_metric_name acc_avg --log_dir log --log_every 50 --save_step 1000 --save_best --save_last --seed 0 --evaluate_all_splits --root_dir {}\".format(\n", - " ROOTDIR).split()\n", - "\n", - "parser = argparse.ArgumentParser()\n", - "\n", - "# Dataset\n", - "parser.add_argument('-d', '--dataset', choices=['encodeTFBS', 'amazon', 'camelyon17', 'celebA', 'civilcomments', 'iwildcam', 'waterbirds', 'yelp', 'poverty', 'fmow', 'ogbg-molpcba'], required=True)\n", - "parser.add_argument('--split_scheme', default='standard',\n", - " help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')\n", - "parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--root_dir', default=None, required=True,\n", - " help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')\n", - "parser.add_argument('--download', default=False, action='store_true',\n", - " help='If true, tries to downloads the dataset if it does not exist in root_dir.')\n", - "parser.add_argument('--frac', type=float, default=1.0,\n", - " help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.')\n", - "\n", - "# Loaders\n", - "parser.add_argument('--train_loader', choices=['standard', 'group'], default='standard')\n", - "parser.add_argument('--train_loader_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--eval_loader', choices=['standard', 'group'], default='standard')\n", - "parser.add_argument('--batch_size', type=int, default=32)\n", - "parser.add_argument('--no_pin_memory', action='store_true') # TODO: put as loader_kwargs\n", - "parser.add_argument('--num_workers', type=int, default=4) # TODO: put as loader kwargs\n", - "\n", - "# Model\n", - "parser.add_argument(\n", - " '--model',\n", - " choices=['bert-base-uncased', 'inception_v3', 'densenet121', 'wideresnet50', 'resnet50', 'gin-virtual', 'resnet18_ms'],\n", - " default='resnet50')\n", - "parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},\n", - " help='keyword arguments for model initialization passed as key1=value1 key2=value2')\n", - "parser.add_argument('--train_from_scratch', action='store_true', default=False)\n", - "\n", - "# Algorithm and objective\n", - "parser.add_argument('--algorithm', required=True, choices=algorithm_constructors.keys())\n", - "parser.add_argument('--algorithm_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--groupby_fields', nargs='+', default=None)\n", - "parser.add_argument('--loss_function', required=True, choices = losses.keys()) #TODO: make default\n", - "parser.add_argument('--val_metric', default=None)\n", - "\n", - "# Optimization\n", - "parser.add_argument('--n_epochs', type=int, default=4)\n", - "parser.add_argument('--optimizer', default=None, choices=optimizer_attributes.keys())\n", - "parser.add_argument('--lr', type=float, required=True)\n", - "parser.add_argument('--weight_decay', type=float, required=True)\n", - "parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--scheduler', default=None, choices=scheduler_attributes.keys())\n", - "parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')\n", - "parser.add_argument('--scheduler_metric_name')\n", - "\n", - "# Evaluation\n", - "parser.add_argument('--evaluate_all_splits', action='store_true', default=False)\n", - "parser.add_argument('--additional_eval_splits', nargs='+', default=[])\n", - "\n", - "# Misc\n", - "parser.add_argument('--device', type=int, default=0)\n", - "parser.add_argument('--seed', type=int, default=0)\n", - "parser.add_argument('--log_dir', default='./logs')\n", - "parser.add_argument('--log_every', default=50, type=int)\n", - "parser.add_argument('--save_step', type=int, default=None)\n", - "parser.add_argument('--save_best', action='store_true', default=False)\n", - "parser.add_argument('--save_last', action='store_true', default=False)\n", - "parser.add_argument('--save_outputs', action='store_true', default=False)\n", - "parser.add_argument('--no_group_logging', action='store_true', default=False)\n", - "parser.add_argument('--val_metric_decreasing', action='store_true', default=False)\n", - "parser.add_argument('--use_wandb', action='store_true', default=False)\n", - "parser.add_argument('--progress_bar', action='store_true', default=False)\n", - "parser.add_argument('--resume', default=False, action='store_true')\n", - "parser.add_argument('--eval_only', default=False, action='store_true')\n", - "\n", - "args = parser.parse_args(args_kw)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# get_input (idx)" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name '_metadata_df' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0midx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mthis_metadata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_metadata_df\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miloc\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mitime\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mflank_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m400\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mNameError\u001b[0m: name '_metadata_df' is not defined" - ] - } - ], - "source": [ - "idx = 3\n", - "this_metadata = _metadata_df.iloc[idx, :]\n", - "\n", - "itime = time.time()\n", - "flank_size = 400\n", - "interval_start = this_metadata['start'] - flank_size\n", - "interval_end = this_metadata['stop'] + flank_size\n", - "dnase_this = _dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end]\n", - "seq_this = _seq_bp[this_metadata['chr']][interval_start:interval_end]\n", - "data = np.column_stack([seq_this, dnase_this])\n", - "# print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.028102874755859375\n" - ] - } - ], - "source": [ - "itime = time.time()\n", - "metadata_array = torch.stack(\n", - " (torch.LongTensor(chr_ints), \n", - " torch.LongTensor(celltype_ints), \n", - " _y_array),\n", - " dim=1)\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "metadata": {}, - "outputs": [ - { - "ename": "ModuleNotFoundError", - "evalue": "No module named 'torch_scatter'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m#data.shape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata_loaders\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_train_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mget_eval_loader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m~/dr_benchmark/wilds/common/data_loaders.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msampler\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mWeightedRandomSampler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSubsetRandomSampler\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_counts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msplit_into_groups\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/dr_benchmark/wilds/common/utils.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch_scatter\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mSubset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mpandas\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtypes\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCategoricalDtype\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torch_scatter'" - ] - } - ], - "source": [ - "#data.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 157, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "4600" - ] - }, - "execution_count": 157, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "data.shape\n", - "interval_end\n", - "# itime = time.time()\n", - "# np.save(os.path.join(data_dir, 'stmp.npy'), sa)\n", - "# print(time.time() - itime)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Run training experiment" - ] - }, - { - "cell_type": "code", - "execution_count": 167, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'python3 examples/run_expt.py -d encodeTFBS --algorithm ERM --model densenet121 --split_scheme standard --groupby_fields hospital --loss_function cross_entropy --optimizer SGD --lr 0.0001 --batch_size 32 --weight_decay 0 --n_epochs 10 --scheduler ReduceLROnPlateau --scheduler_metric_split val --scheduler_metric_name acc_avg --log_dir log --log_every 50 --save_step 1000 --save_best --save_last --seed 0 --evaluate_all_splits --root_dir ROOTDIR'" - ] - }, - "execution_count": 167, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "cmdstr = \"python3 examples/run_expt.py -d encodeTFBS --algorithm ERM --model densenet121 --split_scheme standard --groupby_fields hospital --loss_function cross_entropy\"\n", - "cmdstr += \" \"\n", - "cmdstr += \"--optimizer SGD --lr 0.0001 --batch_size 32 --weight_decay 0 --n_epochs 10 --scheduler ReduceLROnPlateau --scheduler_metric_split val --scheduler_metric_name acc_avg\"\n", - "cmdstr += \" \"\n", - "cmdstr += \"--log_dir log --log_every 50 --save_step 1000 --save_best --save_last --seed 0 --evaluate_all_splits --root_dir ROOTDIR\"\n", - "cmdstr" - ] - }, - { - "cell_type": "code", - "execution_count": 164, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name '_metadata_array' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0m_metadata_array\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mNameError\u001b[0m: name '_metadata_array' is not defined" - ] - } - ], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 165, - "metadata": {}, - "outputs": [ - { - "ename": "ModuleNotFoundError", - "evalue": "No module named 'torch_scatter'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minsert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'..'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata_loaders\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_train_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mget_eval_loader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 17\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrouper\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCombinatorialGrouper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_counts\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/dr_benchmark/wilds/common/data_loaders.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msampler\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mWeightedRandomSampler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSubsetRandomSampler\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_counts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msplit_into_groups\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/dr_benchmark/wilds/common/utils.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch_scatter\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mSubset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mpandas\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtypes\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCategoricalDtype\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torch_scatter'" - ] - } - ], - "source": [ - "import os, csv\n", - "import time\n", - "import argparse\n", - "import IPython\n", - "import pandas as pd\n", - "import torch\n", - "import torch.nn as nn\n", - "import torchvision\n", - "import sys\n", - "from collections import defaultdict\n", - "# torch.multiprocessing.set_sharing_strategy('file_system')\n", - "\n", - "# TODO: Replace this once we make wilds into an installed package\n", - "sys.path.insert(1, os.path.join(sys.path[0], '..'))\n", - "\n", - "from wilds.common.data_loaders import get_train_loader, get_eval_loader\n", - "from wilds.common.grouper import CombinatorialGrouper\n", - "from wilds.common.utils import get_counts\n", - "\n", - "from models.model_attributes import model_attributes\n", - "from utils import set_seed, Logger, BatchLogger, log_args, ParseKwargs, load\n", - "from train import train, evaluate\n", - "from data import dataset_attributes\n", - "from optimizer import optimizer_attributes\n", - "from scheduler import scheduler_attributes\n", - "from loss import losses\n", - "from utils import log_group_data\n", - "from algorithms.constructors import algorithm_constructors" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "from examples.models.model_attributes import model_attributes" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "ename": "ModuleNotFoundError", - "evalue": "No module named 'utils'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_attributes\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmodel_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mset_seed\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mLogger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mCSVBatchLogger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mParseKwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mload\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 22\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdataset_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0moptimizer_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/dr_benchmark/examples/train.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0msave\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'utils'" - ] - } - ], - "source": [ - "def initialize_algorithm(args, datasets, train_grouper):\n", - " train_dataset = datasets['train']['dataset']\n", - " train_loader = datasets['train']['loader']\n", - "\n", - " # Configure the final layer of the networks used\n", - " # The code below are defaults. Edit this if you need special config for your model.\n", - " if (train_dataset.is_classification) and (train_dataset.y_size == 1):\n", - " # For single-task classification, we have one output per class\n", - " d_out = train_dataset.n_classes\n", - " elif (train_dataset.is_classification) and (train_dataset.y_size > 1) and (train_dataset.n_classes == 2):\n", - " # For multi-task binary classification (each output is the logit for each binary class)\n", - " d_out = train_dataset.y_size\n", - " elif (not train_dataset.is_classification):\n", - " # For regression, we have one output per target dimension\n", - " d_out = train_dataset.y_size\n", - " else:\n", - " raise RuntimeError('d_out not defined.')\n", - " \n", - "\n", - " # Sanity checking input args\n", - " if args.algorithm == 'groupDRO':\n", - " assert args.train_loader_kwargs['uniform_over_groups']\n", - " elif args.algorithm in ['deepCORAL', 'IRM']:\n", - " assert args.train_loader == 'group'\n", - " assert args.train_loader_kwargs['uniform_over_groups']\n", - " assert args.train_loader_kwargs['distinct_groups']\n", - "\n", - " # Other config\n", - " n_train_steps = len(train_loader) * args.n_epochs\n", - "# prediction_fn = dataset_attributes[args.dataset]['prediction_fn']\n", - " loss = losses[args.loss_function]\n", - " metric = dataset_attributes[args.dataset]['metric']\n", - " train_g = train_grouper.metadata_to_group(train_dataset.metadata_array)\n", - " is_group_in_train = get_counts(train_g, train_grouper.n_groups) > 0\n", - " algorithm_constructor = algorithm_constructors[args.algorithm]\n", - " algorithm = algorithm_constructor(\n", - " args=args,\n", - " d_out=d_out,\n", - " grouper=train_grouper,\n", - " loss=loss,\n", - " metric=metric,\n", - " n_train_steps=n_train_steps,\n", - " is_group_in_train=is_group_in_train)\n", - " return algorithm" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def main():\n", - " parser = argparse.ArgumentParser()\n", - "\n", - " # Dataset\n", - " parser.add_argument('-d', '--dataset', choices=dataset_attributes.keys(), required=True)\n", - " parser.add_argument('--split_scheme', default='standard',\n", - " help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')\n", - " parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={})\n", - " parser.add_argument('--root_dir', default=None, required=True,\n", - " help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')\n", - " parser.add_argument('--download', default=False, action='store_true',\n", - " help='If true, tries to downloads the dataset if it does not exist in root_dir.')\n", - " parser.add_argument('--frac', type=float, default=1.0,\n", - " help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.')\n", - "\n", - " # Loaders\n", - " parser.add_argument('--train_loader', choices=['standard', 'group'], default='standard')\n", - " parser.add_argument('--train_loader_kwargs', nargs='*', action=ParseKwargs, default={})\n", - " parser.add_argument('--eval_loader', choices=['standard', 'group'], default='standard')\n", - " parser.add_argument('--batch_size', type=int, default=32)\n", - " parser.add_argument('--no_pin_memory', action='store_true') # TODO: put as loader_kwargs\n", - " parser.add_argument('--num_workers', type=int, default=4) # TODO: put as loader kwargs\n", - "\n", - " # Model\n", - " parser.add_argument(\n", - " '--model',\n", - " choices=model_attributes.keys(),\n", - " default='resnet50')\n", - " parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},\n", - " help='keyword arguments for model initialization passed as key1=value1 key2=value2')\n", - " parser.add_argument('--train_from_scratch', action='store_true', default=False)\n", - "\n", - " # Algorithm and objective\n", - " parser.add_argument('--algorithm', required=True, choices=algorithm_constructors.keys())\n", - " parser.add_argument('--algorithm_kwargs', nargs='*', action=ParseKwargs, default={})\n", - " parser.add_argument('--groupby_fields', nargs='+', default=None)\n", - " parser.add_argument('--loss_function', required=True, choices = losses.keys()) #TODO: make default\n", - " parser.add_argument('--val_metric', default=None)\n", - "\n", - " # Optimization\n", - " parser.add_argument('--n_epochs', type=int, default=4)\n", - " parser.add_argument('--optimizer', default=None, choices=optimizer_attributes.keys())\n", - " parser.add_argument('--lr', type=float, required=True)\n", - " parser.add_argument('--weight_decay', type=float, required=True)\n", - " parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})\n", - " parser.add_argument('--scheduler', default=None, choices=scheduler_attributes.keys())\n", - " parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})\n", - " parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')\n", - " parser.add_argument('--scheduler_metric_name')\n", - "\n", - " # Evaluation\n", - " parser.add_argument('--evaluate_all_splits', action='store_true', default=False)\n", - " parser.add_argument('--additional_eval_splits', nargs='+', default=[])\n", - "\n", - " # Misc\n", - " parser.add_argument('--device', type=int, default=0)\n", - " parser.add_argument('--seed', type=int, default=0)\n", - " parser.add_argument('--log_dir', default='./logs')\n", - " parser.add_argument('--log_every', default=50, type=int)\n", - " parser.add_argument('--save_step', type=int, default=None)\n", - " parser.add_argument('--save_best', action='store_true', default=False)\n", - " parser.add_argument('--save_last', action='store_true', default=False)\n", - " parser.add_argument('--save_outputs', action='store_true', default=False)\n", - " parser.add_argument('--no_group_logging', action='store_true', default=False)\n", - " parser.add_argument('--val_metric_decreasing', action='store_true', default=False)\n", - " parser.add_argument('--use_wandb', action='store_true', default=False)\n", - " parser.add_argument('--progress_bar', action='store_true', default=False)\n", - " parser.add_argument('--resume', default=False, action='store_true')\n", - " parser.add_argument('--eval_only', default=False, action='store_true')\n", - "\n", - " args = parser.parse_args()\n", - "\n", - " # set device\n", - " args.device = torch.device(\"cuda:\" + str(args.device)) if torch.cuda.is_available() else torch.device(\"cpu\")\n", - "\n", - " # Set defaults\n", - " if args.groupby_fields is None:\n", - " args.no_group_logging = True\n", - " if args.val_metric is None:\n", - " args.val_metric = dataset_attributes[args.dataset]['val_metric']\n", - "\n", - " ## Initialize logs\n", - " if os.path.exists(args.log_dir) and args.resume:\n", - " resume=True\n", - " mode='a'\n", - " else:\n", - " resume=False\n", - " mode='w'\n", - " if not os.path.exists(args.log_dir):\n", - " os.makedirs(args.log_dir)\n", - " logger = Logger(os.path.join(args.log_dir, 'log.txt'), mode)\n", - "\n", - " # Record args\n", - " log_args(args, logger)\n", - "\n", - " # Set random seed\n", - " set_seed(args.seed)\n", - "\n", - " # Data\n", - " full_dataset = dataset_attributes[args.dataset]['constructor'](\n", - " root_dir=args.root_dir,\n", - " download=args.download,\n", - " split_scheme=args.split_scheme,\n", - " **args.dataset_kwargs)\n", - "\n", - " # To implement data augmentation (i.e., have different transforms\n", - " # at training time vs. test time), modify these two lines:\n", - " train_transform = dataset_attributes[args.dataset]['transform'](args.model)\n", - " if dataset_attributes[args.dataset].get('eval_transform') is None:\n", - " eval_transform = dataset_attributes[args.dataset]['transform'](args.model)\n", - " else:\n", - " eval_transform = dataset_attributes[args.dataset]['eval_transform'](args.model)\n", - "\n", - " train_grouper = CombinatorialGrouper(\n", - " dataset=full_dataset,\n", - " groupby_fields=args.groupby_fields)\n", - "\n", - " datasets = defaultdict(dict)\n", - " for split in full_dataset.split_dict.keys():\n", - " if split=='train':\n", - " transform = train_transform\n", - " verbose = True\n", - " elif split == 'val':\n", - " transform = eval_transform\n", - " verbose = True\n", - " else:\n", - " transform = eval_transform\n", - " verbose = False\n", - " # Get subset\n", - " datasets[split]['dataset'] = full_dataset.get_subset(\n", - " split,\n", - " frac=args.frac,\n", - " transform=transform)\n", - "\n", - " # Get loader\n", - " shared_loader_kwargs = {\n", - " 'num_workers': args.num_workers,\n", - " 'pin_memory': not args.no_pin_memory,\n", - " 'batch_size': args.batch_size,\n", - " 'collate_fn': dataset_attributes[args.dataset]['collate']\n", - " }\n", - "\n", - " if split == 'train':\n", - " datasets[split]['loader'] = get_train_loader(\n", - " loader=args.train_loader,\n", - " dataset=datasets[split]['dataset'],\n", - " grouper=train_grouper,\n", - " train_loader_kwargs=args.train_loader_kwargs,\n", - " **shared_loader_kwargs)\n", - " else:\n", - " datasets[split]['loader'] = get_eval_loader(\n", - " loader=args.eval_loader,\n", - " dataset=datasets[split]['dataset'],\n", - " grouper=train_grouper,\n", - " **shared_loader_kwargs)\n", - "\n", - " # Set fields\n", - " datasets[split]['split'] = split\n", - " datasets[split]['name'] = full_dataset.split_names[split]\n", - " datasets[split]['verbose'] = verbose\n", - " # Loggers\n", - " # Loggers\n", - " datasets[split]['eval_logger'] = BatchLogger(\n", - " os.path.join(args.log_dir, f'{split}_eval.csv'), mode=mode, use_wandb=args.use_wandb)\n", - " datasets[split]['algo_logger'] = BatchLogger(\n", - " os.path.join(args.log_dir, f'{split}_algo.csv'), mode=mode, use_wandb=args.use_wandb)\n", - "\n", - " if args.use_wandb:\n", - " initialize_wandb(args)\n", - "\n", - " # Logging dataset info\n", - " if args.no_group_logging and full_dataset.is_classification and full_dataset.y_size==1:\n", - " log_grouper = CombinatorialGrouper(\n", - " dataset=full_dataset,\n", - " groupby_fields=['y'])\n", - " elif args.no_group_logging:\n", - " log_grouper = None\n", - " else:\n", - " log_grouper = train_grouper\n", - " log_group_data(args, datasets, log_grouper, logger)\n", - "\n", - " ## Initialize algorithm\n", - " algorithm = initialize_algorithm(args, datasets, train_grouper)\n", - "\n", - " if not args.eval_only:\n", - " ## Load saved results if resuming\n", - " resume_success = False\n", - " if resume:\n", - " save_path = os.path.join(args.log_dir, 'last_model.pth')\n", - " if not os.path.exists(save_path):\n", - " epochs = [\n", - " int(file.split('_')[0])\n", - " for file in os.listdir(args.log_dir) if file.endswith('.pth')]\n", - " if len(epochs) > 0:\n", - " latest_epoch = max(epochs)\n", - " save_path = os.path.join(args.log_dir, f'{latest_epoch}_model.pth')\n", - " try:\n", - " prev_epoch, best_val_metric = load(algorithm, save_path)\n", - " epoch_offset = prev_epoch + 1\n", - " logger.write(f'Resuming from epoch {epoch_offset} with best val metric {best_val_metric}')\n", - " resume_success = True\n", - " except FileNotFoundError:\n", - " pass\n", - "\n", - " if resume_success == False:\n", - " epoch_offset=0\n", - " best_val_metric=None\n", - "\n", - "\n", - " train(algorithm,\n", - " datasets,\n", - " logger,\n", - " args,\n", - " epoch_offset=epoch_offset,\n", - " best_val_metric=best_val_metric)\n", - " else:\n", - " best_model_path = os.path.join(args.log_dir, 'best_model.pth')\n", - " best_epoch, best_val_metric = load(algorithm, best_model_path)\n", - " evaluate(algorithm, datasets, best_epoch, logger)\n", - "\n", - " logger.close()\n", - " for split in datasets:\n", - " datasets[split]['eval_logger'].close()\n", - " datasets[split]['algo_logger'].close()\n", - "\n", - "if __name__=='__main__':\n", - " main()\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 2", - "language": "python", - "name": "python2" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 2 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.13" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/dataset_preprocessing/encode-tfbs/README.md b/dataset_preprocessing/encode-tfbs/README.md new file mode 100644 index 00000000..0be5fbd6 --- /dev/null +++ b/dataset_preprocessing/encode-tfbs/README.md @@ -0,0 +1,18 @@ +## ENCODE-TFBS-wilds feature generation and preprocessing + +#### Requirements +- pyBigWig + +#### Instructions + +1. Download the human genome sequence (hg19 assembly) in FASTA format from http://hgdownload.cse.ucsc.edu/goldenpath/hg19/bigZips/hg19.fa.gz into `SEQUENCE_PATH`. + +2. Run `python prep_sequence.py --seq_path SEQUENCE_PATH --output_dir OUTPUT_DIR` to write the fasta file found in `SEQUENCE_PATH` to a numpy array archive in `OUTPUT_DIR`. + +3. Download the accessibility data from the challenge. This consists of whole-genome DNase files in bigwig format (*.bw) from https://www.synapse.org/#!Synapse:syn6176233. + +4. Run `python prep_accessibility.py --input_dir INPUT_DIR --output_dir OUTPUT_DIR` to extract the bigwigs into numpy array archives, one per celltype. + +5. Download the labels from the challenge into a label directory created for this purpose: + - The training labels from https://www.synapse.org/#!Synapse:syn7413983 for the relevant transcription factor (e.g. https://www.synapse.org/#!Synapse:syn7415202 for the TF MAX). + - The validation labels from https://www.synapse.org/#!Synapse:syn8441154 for the relevant transcription factor (e.g. https://www.synapse.org/#!Synapse:syn8442103 for the TF MAX). diff --git a/dataset_preprocessing/encode-tfbs/prep_accessibility.py b/dataset_preprocessing/encode-tfbs/prep_accessibility.py new file mode 100644 index 00000000..9033224e --- /dev/null +++ b/dataset_preprocessing/encode-tfbs/prep_accessibility.py @@ -0,0 +1,41 @@ +import numpy, pandas +import pyBigWig + +from tqdm import tqdm + + +def generate_accessibility_archives(input_dir, output_dir): + dnases = {} + celltypes = ['A549', 'GM12878', 'H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] + + for ctype in celltypes:#glob.glob('dnase_bigwigs/*'): + itime = time.time() + # ctype = pth.split('/')[1].split('.')[1] + bw = pyBigWig.open("{}/DNASE.{}.fc.signal.bigwig".format(input_dir, ctype)) + chromsizes = bw.chroms() + print(ctype, time.time() - itime) + dn_dict = {} + for chrom in chromsizes: #chr_IDs: + x = bw.values(chrom, 0, chromsizes[chrom], numpy=True) + dn_dict[chrom] = np.nan_to_num(x).astype(np.float16) # half-precision makes things significantly smaller (less time to load) + print(chrom, time.time() - itime) + dnases[ctype] = dn_dict + + for ctype in dnases: + itime = time.time() + dn_dict = dnases[ctype] + + # Save as npz archive + np.savez_compressed('{}/{}_dnase'.format(output_dir, ctype), **dn_dict) + print("Saving npz archive for celltype {}. Time: {}".format(ctype, time.time() - itime)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--input_dir', required=True) + parser.add_argument('--output_dir', required=True) + args = parser.parse_args() + + generate_accessibility_archives( + input_dir=args.input_dir, + output_dir=args.output_dir) \ No newline at end of file diff --git a/dataset_preprocessing/encode-tfbs/prep_datasets.ipynb b/dataset_preprocessing/encode-tfbs/prep_datasets.ipynb new file mode 100644 index 00000000..4b1fdc10 --- /dev/null +++ b/dataset_preprocessing/encode-tfbs/prep_datasets.ipynb @@ -0,0 +1,279 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import prep_utils, scipy, numpy as np, time\n", + "from scipy import sparse\n", + "\n", + "# Human chromosome names\n", + "chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Sequence" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "62743362it [00:54, 1151676.47it/s]\n" + ] + } + ], + "source": [ + "a = prep_utils.read_fasta('sequence/hg19.genome.fa')\n", + "\n", + "kw_dict = {}\n", + "itime = time.time()\n", + "for chrom in chr_IDs:\n", + " seqstr = a[chrom]\n", + " kw_dict[chrom] = prep_utils.one_hot_encode(seqstr, alphabet=['A', 'C', 'G', 'T', 'N'])\n", + " print(chrom, time.time() - itime)\n", + "\n", + "# Save as npz archive; can take several (>20) minutes\n", + "print(\"Saving npz archive...\")\n", + "np.savez_compressed('codalab_archive/sequence', **kw_dict)\n", + "print(time.time() - itime)\n", + "\n", + "# # Save as npy arrays\n", + "# itime = time.time()\n", + "# for chrom in kw_dict:\n", + "# np.save('sequence/{}.npy'.format(chrom), kw_dict[chrom])\n", + "# print(chrom, time.time() - itime)\n", + "\n", + "npz_archive = np.load('codalab_archive/sequence.npz')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## DNase" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "liver 0.006468534469604492\n", + "chr1 8.260387659072876\n", + "chr1 13.276052474975586\n", + "chr10 17.844778299331665\n", + "chr10 25.784512758255005\n", + "chr11 30.30143165588379\n", + "chr11 33.256701707839966\n", + "chr12 37.791435956954956\n", + "chr12 40.85292291641235\n", + "chr13 44.619521141052246\n", + "chr13 47.792500495910645\n", + "chr14 51.4214243888855\n", + "chr14 53.6813702583313\n", + "chr15 56.946401834487915\n", + "chr15 59.10466551780701\n", + "chr16 61.939475774765015\n", + "chr16 63.999470472335815\n", + "chr17 66.63648653030396\n", + "chr17 68.4126443862915\n", + "chr18 71.05454993247986\n", + "chr18 72.90085673332214\n", + "chr19 74.78594756126404\n", + "chr19 76.80954170227051\n", + "chr2 85.25815343856812\n", + "chr2 95.36479425430298\n", + "chr20 97.74516272544861\n", + "chr20 99.27151441574097\n", + "chr21 100.82207584381104\n", + "chr21 103.02815318107605\n", + "chr22 104.63926863670349\n", + "chr22 106.02127361297607\n", + "chr3 112.71910071372986\n", + "chr3 117.30491018295288\n", + "chr4 123.77405095100403\n", + "chr4 128.67069339752197\n", + "chr5 134.89299392700195\n", + "chr5 138.83413815498352\n", + "chr6 144.83386087417603\n", + "chr6 149.115407705307\n", + "chr7 154.4929392337799\n", + "chr7 157.8094253540039\n", + "chr8 162.8749077320099\n", + "chr8 165.9331293106079\n", + "chr9 170.5435709953308\n", + "chr9 173.46287417411804\n", + "chrX 178.5410988330841\n", + "chrX 185.49569463729858\n", + "chrY 187.14469981193542\n", + "chrY 189.6306025981903\n", + "MCF-7 0.01819300651550293\n", + "chr1 8.266149282455444\n", + "chr1 13.86928129196167\n", + "chr10 18.216674327850342\n", + "chr10 20.975315809249878\n", + "chr11 25.302175998687744\n", + "chr11 34.40013885498047\n", + "chr12 38.70525503158569\n", + "chr12 41.59175777435303\n", + "chr13 45.130286693573\n", + "chr13 47.67305374145508\n", + "chr14 51.26033353805542\n", + "chr14 53.59153509140015\n", + "chr15 56.858047008514404\n", + "chr15 59.08759665489197\n", + "chr16 62.03992414474487\n", + "chr16 63.99170207977295\n", + "chr17 67.05595779418945\n", + "chr17 69.3644654750824\n", + "chr18 71.78018283843994\n", + "chr18 73.58044695854187\n", + "chr19 75.70175457000732\n", + "chr19 79.72573828697205\n", + "chr2 87.675612449646\n", + "chr2 92.91672372817993\n", + "chr20 95.51653027534485\n", + "chr20 96.88600373268127\n", + "chr21 98.43806076049805\n", + "chr21 103.25369572639465\n", + "chr22 104.84882092475891\n", + "chr22 106.21143817901611\n", + "chr3 112.67947244644165\n", + "chr3 116.70610451698303\n", + "chr4 122.56520342826843\n", + "chr4 126.52856135368347\n", + "chr5 132.38469552993774\n", + "chr5 136.28370690345764\n", + "chr6 141.5743978023529\n", + "chr6 145.10061717033386\n", + "chr7 150.44007444381714\n", + "chr7 155.55760312080383\n", + "chr8 160.3683557510376\n", + "chr8 163.43416213989258\n", + "chr9 167.90313267707825\n", + "chr9 172.0667405128479\n", + "chrX 176.69336795806885\n", + "chrX 181.83150935173035\n", + "K562 0.007167339324951172\n", + "chr1 8.471662998199463\n", + "chr1 13.464861631393433\n", + "chr10 17.858335494995117\n", + "chr10 20.700791835784912\n", + "chr11 25.168848276138306\n", + "chr11 28.01260733604431\n", + "chr12 32.38129758834839\n", + "chr12 35.250038385391235\n", + "chr13 38.72063398361206\n", + "chr13 43.30442762374878\n", + "chr14 46.55065989494324\n", + "chr14 51.87103271484375\n", + "chr15 55.08980083465576\n", + "chr15 57.35198903083801\n", + "chr16 60.444990396499634\n", + "chr16 62.56146717071533\n", + "chr17 65.33607196807861\n", + "chr17 75.77480912208557\n", + "chr18 78.25007915496826\n", + "chr18 82.4424319267273\n", + "chr19 84.73718905448914\n", + "chr19 86.0900673866272\n", + "chr2 93.6916708946228\n", + "chr2 98.61803960800171\n", + "chr20 100.70567536354065\n", + "chr20 102.18551921844482\n", + "chr21 103.75095820426941\n", + "chr21 104.96330642700195\n", + "chr22 106.666348695755\n", + "chr22 108.20869731903076\n", + "chr3 114.6058874130249\n", + "chr3 123.16646194458008\n", + "chr4 129.07538533210754\n", + "chr4 135.95439338684082\n", + "chr5 141.63543701171875\n", + "chr5 148.8255476951599\n", + "chr6 154.68585968017578\n", + "chr6 160.3087387084961\n", + "chr7 165.7410364151001\n", + "chr7 169.09255123138428\n", + "chr8 173.68864274024963\n", + "chr8 176.73100185394287\n", + "chr9 181.10383462905884\n", + "chr9 184.0267071723938\n", + "chrX 188.59823846817017\n", + "chrX 191.7538366317749\n" + ] + } + ], + "source": [ + "### import pyBigWig\n", + "import glob\n", + "\n", + "dnases = {}\n", + "celltypes = ['A549', 'GM12878', 'H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", + "\n", + "for ctype in celltypes:#glob.glob('dnase_bigwigs/*'):\n", + " itime = time.time()\n", + " # ctype = pth.split('/')[1].split('.')[1]\n", + " if ctype not in ['liver', 'MCF-7', 'K562']:\n", + " continue\n", + " bw = pyBigWig.open(\"dnase_bigwigs/DNASE.{}.fc.signal.bigwig\".format(ctype))\n", + " chromsizes = bw.chroms()\n", + " print(ctype, time.time() - itime)\n", + " dn_dict = {}\n", + " for chrom in chromsizes: #chr_IDs:\n", + " x = bw.values(chrom, 0, chromsizes[chrom], numpy=True)\n", + " dn_dict[chrom] = np.nan_to_num(x).astype(np.float16) # half-precision makes things significantly smaller (less time to load)\n", + " print(chrom, time.time() - itime)\n", + " \n", + " np.save('dnase/{}/{}.npy'.format(ctype, chrom), dn_dict[chrom])\n", + " print(chrom, time.time() - itime)\n", + " dnases[ctype] = dn_dict\n", + "\n", + "for ctype in dnases:\n", + " itime = time.time()\n", + " print(ctype)\n", + " dn_dict = dnases[ctype]\n", + " \n", + " # Save as npz archive\n", + " np.savez_compressed('codalab_archive/{}_dnase'.format(ctype), **dn_dict)\n", + " print(time.time() - itime)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "language": "python", + "name": "python2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/dataset_preprocessing/encode-tfbs/prep_sequence.py b/dataset_preprocessing/encode-tfbs/prep_sequence.py new file mode 100644 index 00000000..5a0baea5 --- /dev/null +++ b/dataset_preprocessing/encode-tfbs/prep_sequence.py @@ -0,0 +1,151 @@ +import argparse, time +import numpy, pandas + +from tqdm import tqdm + + +def one_hot_encode(sequence, ignore='N', alphabet=None, dtype='int8', + verbose=False, **kwargs): + """Converts a string or list of characters into a one-hot encoding. + This function will take in either a string or a list and convert it into a + one-hot encoding. If the input is a string, each character is assumed to be + a different symbol, e.g. 'ACGT' is assumed to be a sequence of four + characters. If the input is a list, the elements can be any size. + Although this function will be used here primarily to convert nucleotide + sequences into one-hot encoding with an alphabet of size 4, in principle + this function can be used for any types of sequences. + Parameters + ---------- + sequence : str or list + The sequence to convert to a one-hot encoding. + ignore : str, optional + A character to indicate setting nothing to 1 for that row, keeping the + encoding entirely 0's for that row. In the context of genomics, this is + the N character. Default is 'N'. + alphabet : set or tuple or list, optional + A pre-defined alphabet. If None is passed in, the alphabet will be + determined from the sequence, but this may be time consuming for + large sequences. Default is None. + dtype : str or numpy.dtype, optional + The data type of the returned encoding. Default is int8. + verbose : bool or str, optional + Whether to display a progress bar. If a string is passed in, use as the + name of the progressbar. Default is False. + kwargs : arguments + Arguments to be passed into tqdm. Default is None. + Returns + ------- + ohe : numpy.ndarray + A binary matrix of shape (alphabet_size, sequence_length) where + alphabet_size is the number of unique elements in the sequence and + sequence_length is the length of the input sequence. + """ + + name = None if verbose in (True, False) else verbose + d = verbose is False + + if isinstance(sequence, str): + sequence = list(sequence) + + alphabet = alphabet or numpy.unique(sequence) + alphabet = [char for char in alphabet if char != ignore] + alphabet_lookup = {char: i for i, char in enumerate(alphabet)} + + ohe = numpy.zeros((len(sequence), len(alphabet)), dtype=dtype) + for i, char in tqdm(enumerate(sequence), disable=d, desc=name, **kwargs): + if char != ignore: + idx = alphabet_lookup[char] + ohe[i, idx] = 1 + + return ohe + + +def read_fasta(filename, include_chroms=None, exclude_chroms=None, + ignore='N', alphabet=['A', 'C', 'G', 'T', 'N'], verbose=True): + """Read in a FASTA file and output a dictionary of sequences. + This function will take in the path to a FASTA-formatted file and output + a string containing the sequence for each chromosome. Optionally, + the user can specify a set of chromosomes to include or exclude from + the returned dictionary. + Parameters + ---------- + filename : str + The path to the FASTA-formatted file to open. + include_chroms : set or tuple or list, optional + The exact names of chromosomes in the FASTA file to include, excluding + all others. If None, include all chromosomes (except those specified by + exclude_chroms). Default is None. + exclude_chroms : set or tuple or list, optional + The exact names of chromosomes in the FASTA file to exclude, including + all others. If None, include all chromosomes (or the set specified by + include_chroms). Default is None. + ignore : str, optional + A character to indicate setting nothing to 1 for that row, keeping the + encoding entirely 0's for that row. In the context of genomics, this is + the N character. Default is 'N'. + alphabet : set or tuple or list, optional + A pre-defined alphabet. If None is passed in, the alphabet will be + determined from the sequence, but this may be time consuming for + large sequences. Must include the ignore character. Default is + ['A', 'C', 'G', 'T', 'N']. + verbose : bool or str, optional + Whether to display a progress bar. If a string is passed in, use as the + name of the progressbar. Default is False. + Returns + ------- + chroms : dict + A dictionary of strings where the keys are the names of the + chromosomes (exact strings from the header lines in the FASTA file) + and the values are the strings encoded there. + """ + + sequences = {} + name, sequence = None, None + skip_chrom = False + + with open(filename, "r") as infile: + for line in tqdm(infile, disable=not verbose): + if line.startswith(">"): + if name is not None and skip_chrom is False: + sequences[name] = ''.join(sequence) + + sequence = [] + name = line[1:].strip("\n") + if include_chroms is not None and name not in include_chroms: + skip_chrom = True + elif exclude_chroms is not None and name in exclude_chroms: + skip_chrom = True + else: + skip_chrom = False + + else: + if skip_chrom == False: + sequence.append(line.rstrip("\n").upper()) + + return sequences + + +def generate_sequence_archive(seq_path='sequence/hg19.genome.fa', output_dir): + fasta_contents = read_fasta() + kw_dict = {} + itime = time.time() + for chrom in chr_IDs: + seqstr = fasta_contents[chrom] + kw_dict[chrom] = one_hot_encode(seqstr, alphabet=['A', 'C', 'G', 'T', 'N']) + print(chrom, time.time() - itime) + + # Save as npz archive; can take several (>20) minutes + print("Saving npz archive...") + np.savez_compressed('{}/sequence'.format(output_root), **kw_dict) + print(time.time() - itime) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--seq_path', required=True) + parser.add_argument('--output_dir', required=True) + args = parser.parse_args() + + generate_sequence_archive( + seq_path=args.seq_path, + output_dir=args.output_dir) \ No newline at end of file diff --git a/dataset_preprocessing/encode-tfbs/sandbox_data.ipynb b/sandbox_data.ipynb similarity index 97% rename from dataset_preprocessing/encode-tfbs/sandbox_data.ipynb rename to sandbox_data.ipynb index b2e74829..55a67da4 100644 --- a/dataset_preprocessing/encode-tfbs/sandbox_data.ipynb +++ b/sandbox_data.ipynb @@ -16,8 +16,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "57.5289368629\n", - "65.2459537983\n" + "50.2965240479\n", + "58.1326179504\n" ] } ], @@ -73,37 +73,37 @@ "name": "stdout", "output_type": "stream", "text": [ - "0.0467748641968\n", - "('chr1', 4.52302885055542)\n", - "('chr2', 8.645489931106567)\n", - "('chr3', 11.959153890609741)\n", - "('chr4', 15.15813684463501)\n", - "('chr5', 18.22238802909851)\n", - "('chr6', 21.19420099258423)\n", - "('chr7', 23.940655946731567)\n", - "('chr8', 26.415233850479126)\n", - "('chr9', 28.833614826202393)\n", - "('chr10', 31.08920383453369)\n", - "('chr11', 33.37020301818848)\n", - "('chr12', 35.98973989486694)\n", - "('chr13', 37.88540601730347)\n", - "('chr14', 39.68082284927368)\n", - "('chr15', 41.242313861846924)\n", - "('chr16', 42.74874496459961)\n", - "('chr17', 44.12280797958374)\n", - "('chr18', 45.46893382072449)\n", - "('chr19', 46.50577902793884)\n", - "('chr20', 47.59563183784485)\n", - "('chr21', 48.31779384613037)\n", - "('chr22', 49.17265295982361)\n", - "('chrX', 51.75806999206543)\n", - "('H1-hESC', 25.880441904067993)\n", - "('HCT116', 50.130937814712524)\n", - "('HeLa-S3', 75.29559993743896)\n", - "('HepG2', 102.25979495048523)\n", - "('K562', 128.43050694465637)\n", - "('A549', 154.80679488182068)\n", - "('GM12878', 182.0279529094696)\n" + "1.40137600899\n", + "('chr1', 4.365410089492798)\n", + "('chr2', 8.54686713218689)\n", + "('chr3', 11.915641069412231)\n", + "('chr4', 15.147382020950317)\n", + "('chr5', 18.221237182617188)\n", + "('chr6', 21.16081714630127)\n", + "('chr7', 23.87936806678772)\n", + "('chr8', 26.382845163345337)\n", + "('chr9', 28.802964210510254)\n", + "('chr10', 31.10539698600769)\n", + "('chr11', 33.392733097076416)\n", + "('chr12', 35.6597261428833)\n", + "('chr13', 37.56297421455383)\n", + "('chr14', 39.363978147506714)\n", + "('chr15', 41.089357137680054)\n", + "('chr16', 42.6117000579834)\n", + "('chr17', 43.9806342124939)\n", + "('chr18', 45.29493808746338)\n", + "('chr19', 46.26894497871399)\n", + "('chr20', 47.31300115585327)\n", + "('chr21', 48.139018058776855)\n", + "('chr22', 48.97876214981079)\n", + "('chrX', 51.61549210548401)\n", + "('H1-hESC', 24.14024806022644)\n", + "('HCT116', 47.97159004211426)\n", + "('HeLa-S3', 72.82926392555237)\n", + "('HepG2', 97.18733406066895)\n", + "('K562', 121.94148206710815)\n", + "('A549', 147.29550194740295)\n", + "('GM12878', 171.71312499046326)\n" ] } ], @@ -118,6 +118,7 @@ "for chrom in seq_arr:\n", " _seq_bp[chrom] = seq_arr[chrom]\n", " print(chrom, time.time() - itime)\n", + "print(\"Sequence read. Time: {}\".format(time.time() - itime))\n", "\n", "itime = time.time()\n", "_dnase_allcelltypes = {}\n", @@ -127,9 +128,17 @@ " _dnase_allcelltypes[ct] = {}\n", " for chrom in _seq_bp:\n", " _dnase_allcelltypes[ct][chrom] = dnase_npz_file[chrom]\n", - " print(ct, time.time() - itime)" + " print(ct, time.time() - itime)\n", + "print(\"DNase read for all celltypes. Time: {}\".format(time.time() - itime))" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": 10, diff --git a/sandbox_model.ipynb b/sandbox_model.ipynb index 885c8a59..2d62b55e 100644 --- a/sandbox_model.ipynb +++ b/sandbox_model.ipynb @@ -105,226 +105,294 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class Beagle2(nn.Module):\n", + " \"\"\"\n", + " Neural net models over genomic sequence.\n", + " Input:\n", + " - sequence_length: int (default 1000) \n", + " - Shape: (N, 5, sequence_length, 1) with batch size N.\n", + " \n", + " Output:\n", + " - prediction (Tensor): float torch tensor of shape (N, )\n", + " \n", + " TODO: Finish docstring.\n", + " \"\"\"\n", + " def __init__(self):\n", + " \"\"\"\n", + " Parameters\n", + " ----------\n", + " sequence_length : int\n", + " n_genomic_features : int\n", + " \"\"\"\n", + " super(Beagle2, self).__init__()\n", + "\n", + " self.dropout = 0.3\n", + " self.num_cell_types = 1\n", + " self.conv1 = nn.Conv2d(5, 300, (19, 1), stride = (1, 1), padding=(9,0))\n", + " self.conv2 = nn.Conv2d(300, 200, (11, 1), stride = (1, 1), padding = (5,0))\n", + " self.conv3 = nn.Conv2d(200, 200, (7, 1), stride = (1, 1), padding = (4,0))\n", + " self.bn1 = nn.BatchNorm2d(300)\n", + " self.bn2 = nn.BatchNorm2d(200)\n", + " self.bn3 = nn.BatchNorm2d(200)\n", + " self.maxpool1 = nn.MaxPool2d((3, 1))\n", + " self.maxpool2 = nn.MaxPool2d((4, 1))\n", + " self.maxpool3 = nn.MaxPool2d((4, 1))\n", + "\n", + " self.fc1 = nn.Linear(4200, 1000)\n", + " self.bn4 = nn.BatchNorm1d(1000)\n", + "\n", + " self.fc2 = nn.Linear(1000, 1000)\n", + " self.bn5 = nn.BatchNorm1d(1000)\n", + "\n", + " self.fc3 = nn.Linear(1000, self.num_cell_types)\n", + "\n", + " def forward(self, s):\n", + " s = s.permute(0, 2, 1).contiguous() # batch_size x 4 x 1000\n", + " s = s.view(-1, 5, 1000, 1) # batch_size x 4 x 1000 x 1 [4 channels]\n", + " s = self.maxpool1(F.relu(self.bn1(self.conv1(s)))) # batch_size x 300 x 333 x 1\n", + " s = self.maxpool2(F.relu(self.bn2(self.conv2(s)))) # batch_size x 200 x 83 x 1\n", + " s = self.maxpool3(F.relu(self.bn3(self.conv3(s)))) # batch_size x 200 x 21 x 1\n", + " s = s.view(-1, 4200)\n", + " conv_out = s\n", + "\n", + " s = F.dropout(F.relu(self.bn4(self.fc1(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", + " #s = F.dropout(F.relu(self.bn5(self.fc2(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", + " \n", + " \n", + " s = self.fc3(s)\n", + "\n", + " return s, conv_out\n", + "\n", + "\n", + "class DanQ(nn.Module):\n", + " def __init__(self, sequence_length, n_genomic_features):\n", + " \"\"\"\n", + " Parameters\n", + " ----------\n", + " sequence_length : int\n", + " Input sequence length\n", + " n_genomic_features : int\n", + " Total number of features to predict\n", + " \"\"\"\n", + " super(DanQ, self).__init__()\n", + " self.nnet = nn.Sequential(\n", + " nn.Conv1d(4, 320, kernel_size=26),\n", + " nn.ReLU(inplace=True),\n", + " nn.MaxPool1d(\n", + " kernel_size=13, stride=13),\n", + " nn.Dropout(0.2))\n", + "\n", + " self.bdlstm = nn.Sequential(\n", + " nn.LSTM(\n", + " 320, 320, num_layers=1, batch_first=True, bidirectional=True))\n", + "\n", + " self._n_channels = math.floor(\n", + " (sequence_length - 25) / 13)\n", + " self.classifier = nn.Sequential(\n", + " nn.Dropout(0.5),\n", + " nn.Linear(self._n_channels * 640, 925),\n", + " nn.ReLU(inplace=True),\n", + " nn.Linear(925, n_genomic_features),\n", + " nn.Sigmoid())\n", + "\n", + " def forward(self, x):\n", + " \"\"\"Forward propagation of a batch.\n", + " \"\"\"\n", + " out = self.nnet(x)\n", + " reshape_out = out.transpose(0, 1).transpose(0, 2)\n", + " out, _ = self.bdlstm(reshape_out)\n", + " out = out.transpose(0, 1)\n", + " reshape_out = out.contiguous().view(\n", + " out.size(0), 640 * self._n_channels)\n", + " predict = self.classifier(reshape_out)\n", + " return predict\n", + "\n", + "\n", + "class DeepSEA(nn.Module):\n", + " def __init__(self, sequence_length, n_genomic_features):\n", + " \"\"\"\n", + " Parameters\n", + " ----------\n", + " sequence_length : int\n", + " n_genomic_features : int\n", + " \"\"\"\n", + " super(DeepSEA, self).__init__()\n", + " conv_kernel_size = 8\n", + " pool_kernel_size = 4\n", + "\n", + " self.conv_net = nn.Sequential(\n", + " nn.Conv1d(4, 320, kernel_size=conv_kernel_size),\n", + " nn.ReLU(inplace=True),\n", + " nn.MaxPool1d(\n", + " kernel_size=pool_kernel_size, stride=pool_kernel_size),\n", + " nn.Dropout(p=0.2),\n", + "\n", + " nn.Conv1d(320, 480, kernel_size=conv_kernel_size),\n", + " nn.ReLU(inplace=True),\n", + " nn.MaxPool1d(\n", + " kernel_size=pool_kernel_size, stride=pool_kernel_size),\n", + " nn.Dropout(p=0.2),\n", + "\n", + " nn.Conv1d(480, 960, kernel_size=conv_kernel_size),\n", + " nn.ReLU(inplace=True),\n", + " nn.Dropout(p=0.5))\n", + "\n", + " reduce_by = conv_kernel_size - 1\n", + " pool_kernel_size = float(pool_kernel_size)\n", + " self.n_channels = int(\n", + " np.floor(\n", + " (np.floor(\n", + " (sequence_length - reduce_by) / pool_kernel_size)\n", + " - reduce_by) / pool_kernel_size)\n", + " - reduce_by)\n", + " self.classifier = nn.Sequential(\n", + " nn.Linear(960 * self.n_channels, n_genomic_features),\n", + " nn.ReLU(inplace=True),\n", + " nn.Linear(n_genomic_features, n_genomic_features),\n", + " nn.Sigmoid())\n", + "\n", + " def forward(self, x):\n", + " \"\"\"Forward propagation of a batch.\n", + " \"\"\"\n", + " out = self.conv_net(x)\n", + " reshape_out = out.view(out.size(0), 960 * self.n_channels)\n", + " predict = self.classifier(reshape_out)\n", + " return predict" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "class Beagle(nn.Module):\n", + " \"\"\"\n", + " Neural net models over genomic sequence.\n", + " Input:\n", + " - sequence_length: int (default 1000) \n", + " - Shape: (N, 5, sequence_length, 1) with batch size N.\n", + " \n", + " Output:\n", + " - prediction (Tensor): float torch tensor of shape (N, )\n", + " \n", + " TODO: Finish docstring.\n", + " \"\"\"\n", + " def __init__(self):\n", + " \"\"\"\n", + " Parameters\n", + " ----------\n", + " sequence_length : int\n", + " n_genomic_features : int\n", + " \"\"\"\n", + " super(Beagle, self).__init__()\n", + "\n", + " self.dropout = 0.3\n", + " self.num_cell_types = 1\n", + " self.conv1 = nn.Conv2d(5, 300, (19, 1), stride = (1, 1), padding=(9,0))\n", + " self.conv2 = nn.Conv2d(300, 200, (11, 1), stride = (1, 1), padding = (5,0))\n", + " self.conv3 = nn.Conv2d(200, 200, (7, 1), stride = (1, 1), padding = (4,0))\n", + " self.bn1 = nn.BatchNorm2d(300)\n", + " self.bn2 = nn.BatchNorm2d(200)\n", + " self.bn3 = nn.BatchNorm2d(200)\n", + " self.maxpool1 = nn.MaxPool2d((3, 1))\n", + " self.maxpool2 = nn.MaxPool2d((4, 1))\n", + " self.maxpool3 = nn.MaxPool2d((4, 1))\n", + "\n", + " self.fc1 = nn.Linear(4200, 1000)\n", + " self.bn4 = nn.BatchNorm1d(1000)\n", + "\n", + " self.fc2 = nn.Linear(1000, 1000)\n", + " self.bn5 = nn.BatchNorm1d(1000)\n", + "\n", + " self.fc3 = nn.Linear(1000, self.num_cell_types)\n", + "\n", + " def forward(self, s):\n", + " s = s.permute(0, 2, 1).contiguous() # batch_size x 5 x 1000\n", + " s = s.view(-1, 5, 1000, 1) # batch_size x 5 x 1000 x 1 [5 channels]\n", + " s = self.maxpool1(F.relu(self.bn1(self.conv1(s)))) # batch_size x 300 x 333 x 1\n", + " s = self.maxpool2(F.relu(self.bn2(self.conv2(s)))) # batch_size x 200 x 83 x 1\n", + " s = self.maxpool3(F.relu(self.bn3(self.conv3(s)))) # batch_size x 200 x 21 x 1\n", + " s = s.view(-1, 4200)\n", + " conv_out = s\n", + "\n", + " s = F.dropout(F.relu(self.bn4(self.fc1(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", + " s = F.dropout(F.relu(self.bn5(self.fc2(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", + " \n", + " s = self.fc3(s)\n", + "\n", + " return s, conv_out" + ] + }, + { + "cell_type": "code", + "execution_count": 86, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'A549': {'chr1': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr10': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr11': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr12': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr13': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr14': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr15': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr16': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr17': array([ 0.35986328, 0.35986328, 0.35986328, ..., 0. ,\n", - " 0. , 0. ], dtype=float16),\n", - " 'chr18': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr19': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr2': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr20': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr21': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr22': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr3': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr4': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr5': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr6': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr7': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr8': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr9': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chrX': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16)},\n", - " 'GM12878': {'chr1': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr10': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr11': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr12': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr13': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr14': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr15': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr16': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr17': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr18': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr19': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr2': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr20': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr21': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr22': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr3': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr4': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr5': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr6': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr7': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr8': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr9': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chrX': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16)},\n", - " 'H1-hESC': {'chr1': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr10': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr11': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr12': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr13': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr14': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr15': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr16': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr17': array([ 0.71972656, 0.71972656, 0.71972656, ..., 0. ,\n", - " 0. , 0. ], dtype=float16),\n", - " 'chr18': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr19': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr2': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr20': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr21': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr22': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr3': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr4': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr5': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr6': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr7': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr8': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr9': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chrX': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16)},\n", - " 'HCT116': {'chr1': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr10': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr11': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr12': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr13': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr14': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr15': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr16': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr17': array([ 0.80419922, 0.80419922, 0.80419922, ..., 0. ,\n", - " 0. , 0. ], dtype=float16),\n", - " 'chr18': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr19': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr2': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr20': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr21': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr22': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr3': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr4': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr5': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr6': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr7': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr8': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr9': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chrX': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16)},\n", - " 'HeLa-S3': {'chr1': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr10': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr11': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr12': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr13': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr14': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr15': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr16': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr17': array([ 0.71972656, 0.71972656, 0.71972656, ..., 0. ,\n", - " 0. , 0. ], dtype=float16),\n", - " 'chr18': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr19': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr2': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr20': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr21': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr22': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr3': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr4': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr5': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr6': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr7': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr8': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr9': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chrX': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16)},\n", - " 'HepG2': {'chr1': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr10': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr11': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr12': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr13': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr14': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr15': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr16': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr17': array([ 0.71972656, 0.71972656, 0.71972656, ..., 0. ,\n", - " 0. , 0. ], dtype=float16),\n", - " 'chr18': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr19': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr2': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr20': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr21': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr22': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr3': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr4': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr5': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr6': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr7': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr8': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr9': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chrX': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16)},\n", - " 'K562': {'chr1': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr10': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr11': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr12': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr13': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr14': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr15': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr16': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr17': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr18': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr19': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr2': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr20': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr21': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr22': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr3': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr4': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr5': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr6': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr7': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr8': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr9': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chrX': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16)}}" + "[('nnet.0.weight', 33280),\n", + " ('nnet.0.bias', 320),\n", + " ('bdlstm.0.weight_ih_l0', 409600),\n", + " ('bdlstm.0.weight_hh_l0', 409600),\n", + " ('bdlstm.0.bias_ih_l0', 1280),\n", + " ('bdlstm.0.bias_hh_l0', 1280),\n", + " ('bdlstm.0.weight_ih_l0_reverse', 409600),\n", + " ('bdlstm.0.weight_hh_l0_reverse', 409600),\n", + " ('bdlstm.0.bias_ih_l0_reverse', 1280),\n", + " ('bdlstm.0.bias_hh_l0_reverse', 1280),\n", + " ('classifier.1.weight', 592000),\n", + " ('classifier.1.bias', 925),\n", + " ('classifier.3.weight', 4625),\n", + " ('classifier.3.bias', 5)]" ] }, - "execution_count": 5, + "execution_count": 86, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "_dnase_allcelltypes" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "from examples.models import CNN_genome" + "def count_parameters(model):\n", + " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", + "\n", + "model = Beagle2()\n", + "model = DanQ(50, 5)\n", + "\n", + "lst = [(x[0], x[1].numel()) for x in model.named_parameters()]\n", + "#np.sum([x[1] for x in lst])\n", + "count_parameters(model)\n", + "lst" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 48, "metadata": {}, "outputs": [ { - "ename": "TypeError", - "evalue": "unbound method parameters() must be called with Beagle instance as first argument (got nothing instead)", + "ename": "AttributeError", + "evalue": "'module' object has no attribute 'isin'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# def count_parameters(model):\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;31m# return sum(p.numel() for p in model.parameters() if p.requires_grad)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mCNN_genome\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mBeagle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m: unbound method parameters() must be called with Beagle instance as first argument (got nothing instead)" + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mtr_chrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr2'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr9'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr11'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mte_chrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr1'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr8'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr21'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mtraining_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtr_chrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mval_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mval_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mte_chrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mall_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtraining_df\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_df\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mAttributeError\u001b[0m: 'module' object has no attribute 'isin'" ] } ], - "source": [ - "# def count_parameters(model):\n", - "# return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", - "CNN_genome.Beagle.parameters()" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], "source": [ "tr_chrs = ['chr2', 'chr9', 'chr11']\n", "te_chrs = ['chr1', 'chr8', 'chr21']\n", @@ -337,6 +405,23 @@ "all_df = all_df[filter_msk]" ] }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1.12.1\n" + ] + } + ], + "source": [ + "print(np.__version__)" + ] + }, { "cell_type": "code", "execution_count": 30, @@ -528,17 +613,6 @@ "print(time.time() - itime)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "python3 examples/run_expt.py -d camelyon17 --algorithm ERM --model densenet121 --split_scheme standard --groupby_fields hospital --loss_function cross_entropy --optimizer SGD \n", - "--lr 0.0001 --batch_size 32 --weight_decay 0 --n_epochs 10 --scheduler ReduceLROnPlateau --scheduler_metric_split val --scheduler_metric_name acc_avg \n", - "--log_dir log --log_every 50 --save_step 1000 --save_best --save_last --seed 0 --evaluate_all_splits --root_dir ROOTDIR" - ] - }, { "cell_type": "code", "execution_count": 156, diff --git a/wilds/datasets/camelyon17_dataset.py b/wilds/datasets/camelyon17_dataset.py index 0a76f615..d170ddf2 100644 --- a/wilds/datasets/camelyon17_dataset.py +++ b/wilds/datasets/camelyon17_dataset.py @@ -1,141 +1,141 @@ -import os -import torch -import pandas as pd -from PIL import Image -import numpy as np -from wilds.datasets.wilds_dataset import WILDSDataset -from wilds.common.grouper import CombinatorialGrouper -from wilds.common.metrics.all_metrics import Accuracy - -class Camelyon17Dataset(WILDSDataset): - """ - The CAMELYON17-wilds histopathology dataset. - This is a modified version of the original CAMELYON17 dataset. - - Supported `split_scheme`: - 'official' or 'in-dist' - - Input (x): - 96x96 image patches extracted from histopathology slides. - - Label (y): - y is binary. It is 1 if the central 32x32 region contains any tumor tissue, and 0 otherwise. - - Metadata: - Each patch is annotated with the ID of the hospital it came from (integer from 0 to 4) - and the slide it came from (integer from 0 to 49). - - Website: - https://camelyon17.grand-challenge.org/ - - Original publication: - @article{bandi2018detection, - title={From detection of individual metastases to classification of lymph node status at the patient level: the camelyon17 challenge}, - author={Bandi, Peter and Geessink, Oscar and Manson, Quirine and Van Dijk, Marcory and Balkenhol, Maschenka and Hermsen, Meyke and Bejnordi, Babak Ehteshami and Lee, Byungjae and Paeng, Kyunghyun and Zhong, Aoxiao and others}, - journal={IEEE transactions on medical imaging}, - volume={38}, - number={2}, - pages={550--560}, - year={2018}, - publisher={IEEE} - } - - License: - This dataset is in the public domain and is distributed under CC0. - https://creativecommons.org/publicdomain/zero/1.0/ - """ - - def __init__(self, root_dir='data', download=False, split_scheme='official'): - self._dataset_name = 'camelyon17' - self._version = '1.0' - self._download_url = 'https://worksheets.codalab.org/rest/bundles/0xe45e15f39fb54e9d9e919556af67aabe/contents/blob/' - self._compressed_size = 10_658_709_504 - self._data_dir = self.initialize_data_dir(root_dir, download) - self._original_resolution = (96,96) - - # Read in metadata - self._metadata_df = pd.read_csv( - os.path.join(self._data_dir, 'metadata.csv'), - index_col=0, - dtype={'patient': 'str'}) - - # Get the y values - self._y_array = torch.LongTensor(self._metadata_df['tumor'].values) - self._y_size = 1 - self._n_classes = 2 - - # Get filenames - self._input_array = [ - f'patches/patient_{patient}_node_{node}/patch_patient_{patient}_node_{node}_x_{x}_y_{y}.png' - for patient, node, x, y in - self._metadata_df.loc[:, ['patient', 'node', 'x_coord', 'y_coord']].itertuples(index=False, name=None)] - - # Extract splits - # Note that the hospital numbering here is different from what's in the paper, - # where to avoid confusing readers we used a 1-indexed scheme and just labeled the test hospital as 5. - # Here, the numbers are 0-indexed. - test_center = 2 - val_center = 1 - - self._split_dict = { - 'train': 0, - 'id_val': 1, - 'test': 2, - 'val': 3 - } - self._split_names = { - 'train': 'Train', - 'id_val': 'Validation (ID)', - 'test': 'Test', - 'val': 'Validation (OOD)', - } - centers = self._metadata_df['center'].values.astype('long') - num_centers = int(np.max(centers)) + 1 - val_center_mask = (self._metadata_df['center'] == val_center) - test_center_mask = (self._metadata_df['center'] == test_center) - self._metadata_df.loc[val_center_mask, 'split'] = self.split_dict['val'] - self._metadata_df.loc[test_center_mask, 'split'] = self.split_dict['test'] - - self._split_scheme = split_scheme - if self._split_scheme == 'official': - pass - elif self._split_scheme == 'in-dist': - # For the in-distribution oracle, - # we move slide 23 (corresponding to patient 042, node 3 in the original dataset) - # from the test set to the training set - slide_mask = (self._metadata_df['slide'] == 23) - self._metadata_df.loc[slide_mask, 'split'] = self.split_dict['train'] - else: - raise ValueError(f'Split scheme {self._split_scheme} not recognized') - self._split_array = self._metadata_df['split'].values - - self._metadata_array = torch.stack( - (torch.LongTensor(centers), - torch.LongTensor(self._metadata_df['slide'].values), - self._y_array), - dim=1) - self._metadata_fields = ['hospital', 'slide', 'y'] - - self._eval_grouper = CombinatorialGrouper( - dataset=self, - groupby_fields=['slide']) - - self._metric = Accuracy() - - super().__init__(root_dir, download, split_scheme) - - def get_input(self, idx): - """ - Returns x for a given idx. - """ - img_filename = os.path.join( - self.data_dir, - self._input_array[idx]) - x = Image.open(img_filename).convert('RGB') - return x - - def eval(self, y_pred, y_true, metadata): - return self.standard_group_eval( - self._metric, - self._eval_grouper, - y_pred, y_true, metadata) +import os +import torch +import pandas as pd +from PIL import Image +import numpy as np +from wilds.datasets.wilds_dataset import WILDSDataset +from wilds.common.grouper import CombinatorialGrouper +from wilds.common.metrics.all_metrics import Accuracy + +class Camelyon17Dataset(WILDSDataset): + """ + The CAMELYON17-wilds histopathology dataset. + This is a modified version of the original CAMELYON17 dataset. + + Supported `split_scheme`: + 'official' or 'in-dist' + + Input (x): + 96x96 image patches extracted from histopathology slides. + + Label (y): + y is binary. It is 1 if the central 32x32 region contains any tumor tissue, and 0 otherwise. + + Metadata: + Each patch is annotated with the ID of the hospital it came from (integer from 0 to 4) + and the slide it came from (integer from 0 to 49). + + Website: + https://camelyon17.grand-challenge.org/ + + Original publication: + @article{bandi2018detection, + title={From detection of individual metastases to classification of lymph node status at the patient level: the camelyon17 challenge}, + author={Bandi, Peter and Geessink, Oscar and Manson, Quirine and Van Dijk, Marcory and Balkenhol, Maschenka and Hermsen, Meyke and Bejnordi, Babak Ehteshami and Lee, Byungjae and Paeng, Kyunghyun and Zhong, Aoxiao and others}, + journal={IEEE transactions on medical imaging}, + volume={38}, + number={2}, + pages={550--560}, + year={2018}, + publisher={IEEE} + } + + License: + This dataset is in the public domain and is distributed under CC0. + https://creativecommons.org/publicdomain/zero/1.0/ + """ + + def __init__(self, root_dir='data', download=False, split_scheme='official'): + self._dataset_name = 'camelyon17' + self._version = '1.0' + self._download_url = 'https://worksheets.codalab.org/rest/bundles/0xe45e15f39fb54e9d9e919556af67aabe/contents/blob/' + self._compressed_size = 10_658_709_504 + self._data_dir = self.initialize_data_dir(root_dir, download) + self._original_resolution = (96,96) + + # Read in metadata + self._metadata_df = pd.read_csv( + os.path.join(self._data_dir, 'metadata.csv'), + index_col=0, + dtype={'patient': 'str'}) + + # Get the y values + self._y_array = torch.LongTensor(self._metadata_df['tumor'].values) + self._y_size = 1 + self._n_classes = 2 + + # Get filenames + self._input_array = [ + f'patches/patient_{patient}_node_{node}/patch_patient_{patient}_node_{node}_x_{x}_y_{y}.png' + for patient, node, x, y in + self._metadata_df.loc[:, ['patient', 'node', 'x_coord', 'y_coord']].itertuples(index=False, name=None)] + + # Extract splits + # Note that the hospital numbering here is different from what's in the paper, + # where to avoid confusing readers we used a 1-indexed scheme and just labeled the test hospital as 5. + # Here, the numbers are 0-indexed. + test_center = 2 + val_center = 1 + + self._split_dict = { + 'train': 0, + 'id_val': 1, + 'test': 2, + 'val': 3 + } + self._split_names = { + 'train': 'Train', + 'id_val': 'Validation (ID)', + 'test': 'Test', + 'val': 'Validation (OOD)', + } + centers = self._metadata_df['center'].values.astype('long') + num_centers = int(np.max(centers)) + 1 + val_center_mask = (self._metadata_df['center'] == val_center) + test_center_mask = (self._metadata_df['center'] == test_center) + self._metadata_df.loc[val_center_mask, 'split'] = self.split_dict['val'] + self._metadata_df.loc[test_center_mask, 'split'] = self.split_dict['test'] + + self._split_scheme = split_scheme + if self._split_scheme == 'official': + pass + elif self._split_scheme == 'in-dist': + # For the in-distribution oracle, + # we move slide 23 (corresponding to patient 042, node 3 in the original dataset) + # from the test set to the training set + slide_mask = (self._metadata_df['slide'] == 23) + self._metadata_df.loc[slide_mask, 'split'] = self.split_dict['train'] + else: + raise ValueError(f'Split scheme {self._split_scheme} not recognized') + self._split_array = self._metadata_df['split'].values + + self._metadata_array = torch.stack( + (torch.LongTensor(centers), + torch.LongTensor(self._metadata_df['slide'].values), + self._y_array), + dim=1) + self._metadata_fields = ['hospital', 'slide', 'y'] + + self._eval_grouper = CombinatorialGrouper( + dataset=self, + groupby_fields=['slide']) + + self._metric = Accuracy() + + super().__init__(root_dir, download, split_scheme) + + def get_input(self, idx): + """ + Returns x for a given idx. + """ + img_filename = os.path.join( + self.data_dir, + self._input_array[idx]) + x = Image.open(img_filename).convert('RGB') + return x + + def eval(self, y_pred, y_true, metadata): + return self.standard_group_eval( + self._metric, + self._eval_grouper, + y_pred, y_true, metadata) From 6e01e5caf49583853874b809ac8db8327366f6de Mon Sep 17 00:00:00 2001 From: aikanor Date: Mon, 8 Feb 2021 08:51:43 -0800 Subject: [PATCH 005/244] final integration 1/ --- .../encode-tfbs/prep_accessibility.py | 2 + .../encode-tfbs/prep_sequence.py | 2 + examples/configs/supported.py | 6 +- examples/models/CNN_genome.py | 2 +- examples/run_expt.py | 556 +++++++++--------- sandbox_data.ipynb | 24 + 6 files changed, 311 insertions(+), 281 deletions(-) diff --git a/dataset_preprocessing/encode-tfbs/prep_accessibility.py b/dataset_preprocessing/encode-tfbs/prep_accessibility.py index 9033224e..7342f797 100644 --- a/dataset_preprocessing/encode-tfbs/prep_accessibility.py +++ b/dataset_preprocessing/encode-tfbs/prep_accessibility.py @@ -3,6 +3,8 @@ from tqdm import tqdm +# Human chromosome names +chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] def generate_accessibility_archives(input_dir, output_dir): dnases = {} diff --git a/dataset_preprocessing/encode-tfbs/prep_sequence.py b/dataset_preprocessing/encode-tfbs/prep_sequence.py index 5a0baea5..7f396d9f 100644 --- a/dataset_preprocessing/encode-tfbs/prep_sequence.py +++ b/dataset_preprocessing/encode-tfbs/prep_sequence.py @@ -3,6 +3,8 @@ from tqdm import tqdm +# Human chromosome names +chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] def one_hot_encode(sequence, ignore='N', alphabet=None, dtype='int8', verbose=False, **kwargs): diff --git a/examples/configs/supported.py b/examples/configs/supported.py index bcbe54a9..d39e096e 100644 --- a/examples/configs/supported.py +++ b/examples/configs/supported.py @@ -7,6 +7,7 @@ from wilds.datasets.camelyon17_dataset import Camelyon17Dataset from wilds.datasets.celebA_dataset import CelebADataset from wilds.datasets.civilcomments_dataset import CivilCommentsDataset +from wilds.datasets.encodetfbs_dataset import EncodeTFBSDataset from wilds.datasets.fmow_dataset import FMoWDataset from wilds.datasets.iwildcam_dataset import IWildCamDataset from wilds.datasets.ogbmolpcba_dataset import OGBPCBADataset @@ -28,7 +29,8 @@ 'ogb-molpcba': OGBPCBADataset, 'poverty': PovertyMapDataset, 'fmow': FMoWDataset, - 'bdd100k': BDD100KDataset, + 'bdd100k': BDD100KDataset, + 'encodeTFBS': EncodeTFBSDataset, } losses = { @@ -47,7 +49,7 @@ # see initialize_*() functions for correspondence transforms = ['bert', 'image_base', 'image_resize_and_center_crop', 'poverty_train'] models = ['resnet18_ms', 'resnet50', 'resnet34', 'wideresnet50', 'densenet121', 'bert-base-uncased', 'gin-virtual', - 'logistic_regression'] + 'logistic_regression', 'beagle'] algorithms = ['ERM', 'groupDRO', 'deepCORAL', 'IRM'] optimizers = ['SGD', 'Adam', 'AdamW'] schedulers = ['linear_schedule_with_warmup', 'ReduceLROnPlateau', 'StepLR'] diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py index 75295cd3..8a658eab 100644 --- a/examples/models/CNN_genome.py +++ b/examples/models/CNN_genome.py @@ -59,4 +59,4 @@ def forward(self, s): s = self.fc3(s) - return s, conv_out + return s#, conv_out diff --git a/examples/run_expt.py b/examples/run_expt.py index 166df04f..710157c3 100644 --- a/examples/run_expt.py +++ b/examples/run_expt.py @@ -1,278 +1,278 @@ -import os, csv -import time -import argparse -import pandas as pd -import torch -import torch.nn as nn -import torchvision -import sys -from collections import defaultdict - -from wilds.common.data_loaders import get_train_loader, get_eval_loader -from wilds.common.grouper import CombinatorialGrouper - -from utils import set_seed, Logger, BatchLogger, log_config, ParseKwargs, load, initialize_wandb, log_group_data, parse_bool -from train import train, evaluate -from algorithms.initializer import initialize_algorithm -from transforms import initialize_transform -from configs.utils import populate_defaults -import configs.supported as supported - -def main(): - ''' set default hyperparams in default_hyperparams.py ''' - parser = argparse.ArgumentParser() - - # Required arguments - parser.add_argument('-d', '--dataset', choices=supported.datasets, required=True) - parser.add_argument('--algorithm', required=True, choices=supported.algorithms) - parser.add_argument('--root_dir', required=True, - help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).') - - # Dataset - parser.add_argument('--split_scheme', help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.') - parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={}) - parser.add_argument('--download', default=False, type=parse_bool, const=True, nargs='?', - help='If true, tries to downloads the dataset if it does not exist in root_dir.') - parser.add_argument('--frac', type=float, default=1.0, - help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.') - - # Loaders - parser.add_argument('--loader_kwargs', nargs='*', action=ParseKwargs, default={}) - parser.add_argument('--train_loader', choices=['standard', 'group']) - parser.add_argument('--uniform_over_groups', type=parse_bool, const=True, nargs='?') - parser.add_argument('--distinct_groups', type=parse_bool, const=True, nargs='?') - parser.add_argument('--n_groups_per_batch', type=int) - parser.add_argument('--batch_size', type=int) - parser.add_argument('--eval_loader', choices=['standard'], default='standard') - - # Model - parser.add_argument('--model', choices=supported.models) - parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={}, - help='keyword arguments for model initialization passed as key1=value1 key2=value2') - - # Transforms - parser.add_argument('--train_transform', choices=supported.transforms) - parser.add_argument('--eval_transform', choices=supported.transforms) - parser.add_argument('--target_resolution', nargs='+', type=int, help='target resolution. for example --target_resolution 224 224 for standard resnet.') - parser.add_argument('--resize_scale', type=float) - parser.add_argument('--max_token_length', type=int) - - # Objective - parser.add_argument('--loss_function', choices = supported.losses) - - # Algorithm - parser.add_argument('--groupby_fields', nargs='+') - parser.add_argument('--group_dro_step_size', type=float) - parser.add_argument('--coral_penalty_weight', type=float) - parser.add_argument('--irm_lambda', type=float) - parser.add_argument('--irm_penalty_anneal_iters', type=int) - parser.add_argument('--algo_log_metric') - - # Model selection - parser.add_argument('--val_metric') - parser.add_argument('--val_metric_decreasing', type=parse_bool, const=True, nargs='?') - - # Optimization - parser.add_argument('--n_epochs', type=int) - parser.add_argument('--optimizer', choices=supported.optimizers) - parser.add_argument('--lr', type=float) - parser.add_argument('--weight_decay', type=float) - parser.add_argument('--max_grad_norm', type=float) - parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={}) - - # Scheduler - parser.add_argument('--scheduler', choices=supported.schedulers) - parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={}) - parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val') - parser.add_argument('--scheduler_metric_name') - - # Evaluation - parser.add_argument('--evaluate_all_splits', type=parse_bool, const=True, nargs='?', default=True) - parser.add_argument('--eval_splits', nargs='+', default=[]) - parser.add_argument('--eval_only', type=parse_bool, const=True, nargs='?', default=False) - parser.add_argument('--eval_epoch', default=None, type=int) - - # Misc - parser.add_argument('--device', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--log_dir', default='./logs') - parser.add_argument('--log_every', default=50, type=int) - parser.add_argument('--save_step', type=int) - parser.add_argument('--save_best', type=parse_bool, const=True, nargs='?', default=True) - parser.add_argument('--save_last', type=parse_bool, const=True, nargs='?', default=True) - parser.add_argument('--no_group_logging', type=parse_bool, const=True, nargs='?') - parser.add_argument('--use_wandb', type=parse_bool, const=True, nargs='?', default=False) - parser.add_argument('--progress_bar', type=parse_bool, const=True, nargs='?', default=False) - parser.add_argument('--resume', type=parse_bool, const=True, nargs='?', default=False) - - config = parser.parse_args() - config = populate_defaults(config) - - # set device - config.device = torch.device("cuda:" + str(config.device)) if torch.cuda.is_available() else torch.device("cpu") - - ## Initialize logs - if os.path.exists(config.log_dir) and config.resume: - resume=True - mode='a' - elif os.path.exists(config.log_dir) and config.eval_only: - resume=False - mode='a' - else: - resume=False - mode='w' - - if not os.path.exists(config.log_dir): - os.makedirs(config.log_dir) - logger = Logger(os.path.join(config.log_dir, 'log.txt'), mode) - - # Record config - log_config(config, logger) - - # Set random seed - set_seed(config.seed) - - # Data - full_dataset = supported.datasets[config.dataset]( - root_dir=config.root_dir, - download=config.download, - split_scheme=config.split_scheme, - **config.dataset_kwargs) - - # To implement data augmentation (i.e., have different transforms - # at training time vs. test time), modify these two lines: - train_transform = initialize_transform( - transform_name=config.train_transform, - config=config, - dataset=full_dataset) - eval_transform = initialize_transform( - transform_name=config.eval_transform, - config=config, - dataset=full_dataset) - - train_grouper = CombinatorialGrouper( - dataset=full_dataset, - groupby_fields=config.groupby_fields) - - datasets = defaultdict(dict) - for split in full_dataset.split_dict.keys(): - if split=='train': - transform = train_transform - verbose = True - elif split == 'val': - transform = eval_transform - verbose = True - else: - transform = eval_transform - verbose = False - # Get subset - datasets[split]['dataset'] = full_dataset.get_subset( - split, - frac=config.frac, - transform=transform) - - if split == 'train': - datasets[split]['loader'] = get_train_loader( - loader=config.train_loader, - dataset=datasets[split]['dataset'], - batch_size=config.batch_size, - uniform_over_groups=config.uniform_over_groups, - grouper=train_grouper, - distinct_groups=config.distinct_groups, - n_groups_per_batch=config.n_groups_per_batch, - **config.loader_kwargs) - else: - datasets[split]['loader'] = get_eval_loader( - loader=config.eval_loader, - dataset=datasets[split]['dataset'], - grouper=train_grouper, - batch_size=config.batch_size, - **config.loader_kwargs) - - # Set fields - datasets[split]['split'] = split - datasets[split]['name'] = full_dataset.split_names[split] - datasets[split]['verbose'] = verbose - # Loggers - # Loggers - datasets[split]['eval_logger'] = BatchLogger( - os.path.join(config.log_dir, f'{split}_eval.csv'), mode=mode, use_wandb=(config.use_wandb and verbose)) - datasets[split]['algo_logger'] = BatchLogger( - os.path.join(config.log_dir, f'{split}_algo.csv'), mode=mode, use_wandb=(config.use_wandb and verbose)) - - if config.use_wandb: - initialize_wandb(config) - - # Logging dataset info - if config.no_group_logging and full_dataset.is_classification and full_dataset.y_size==1: - log_grouper = CombinatorialGrouper( - dataset=full_dataset, - groupby_fields=['y']) - elif config.no_group_logging: - log_grouper = None - else: - log_grouper = train_grouper - log_group_data(datasets, log_grouper, logger) - - ## Initialize algorithm - algorithm = initialize_algorithm( - config=config, - datasets=datasets, - train_grouper=train_grouper) - - if not config.eval_only: - ## Load saved results if resuming - resume_success = False - if resume: - save_path = os.path.join(config.log_dir, 'last_model.pth') - if not os.path.exists(save_path): - epochs = [ - int(file.split('_')[0]) - for file in os.listdir(config.log_dir) if file.endswith('.pth')] - if len(epochs) > 0: - latest_epoch = max(epochs) - save_path = os.path.join(config.log_dir, f'{latest_epoch}_model.pth') - try: - prev_epoch, best_val_metric = load(algorithm, save_path) - epoch_offset = prev_epoch + 1 - logger.write(f'Resuming from epoch {epoch_offset} with best val metric {best_val_metric}') - resume_success = True - except FileNotFoundError: - pass - - if resume_success == False: - epoch_offset=0 - best_val_metric=None - - - train( - algorithm=algorithm, - datasets=datasets, - general_logger=logger, - config=config, - epoch_offset=epoch_offset, - best_val_metric=best_val_metric) - else: - if config.eval_epoch is None: - eval_model_path = os.path.join(config.log_dir, 'best_model.pth') - else: - eval_model_path = os.path.join(config.log_dir, f'{config.eval_epoch}_model.pth') - best_epoch, best_val_metric = load(algorithm, eval_model_path) - if config.eval_epoch is None: - epoch = best_epoch - else: - epoch = config.eval_epoch - evaluate( - algorithm=algorithm, - datasets=datasets, - epoch=epoch, - general_logger=logger, - config=config) - - logger.close() - for split in datasets: - datasets[split]['eval_logger'].close() - datasets[split]['algo_logger'].close() - -if __name__=='__main__': - main() +import os, csv +import time +import argparse +import pandas as pd +import torch +import torch.nn as nn +import torchvision +import sys +from collections import defaultdict + +from wilds.common.data_loaders import get_train_loader, get_eval_loader +from wilds.common.grouper import CombinatorialGrouper + +from utils import set_seed, Logger, BatchLogger, log_config, ParseKwargs, load, initialize_wandb, log_group_data, parse_bool +from train import train, evaluate +from algorithms.initializer import initialize_algorithm +from transforms import initialize_transform +from configs.utils import populate_defaults +import configs.supported as supported + +def main(): + ''' set default hyperparams in default_hyperparams.py ''' + parser = argparse.ArgumentParser() + + # Required arguments + parser.add_argument('-d', '--dataset', choices=supported.datasets, required=True) + parser.add_argument('--algorithm', required=True, choices=supported.algorithms) + parser.add_argument('--root_dir', required=True, + help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).') + + # Dataset + parser.add_argument('--split_scheme', help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.') + parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={}) + parser.add_argument('--download', default=False, type=parse_bool, const=True, nargs='?', + help='If true, tries to downloads the dataset if it does not exist in root_dir.') + parser.add_argument('--frac', type=float, default=1.0, + help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.') + + # Loaders + parser.add_argument('--loader_kwargs', nargs='*', action=ParseKwargs, default={}) + parser.add_argument('--train_loader', choices=['standard', 'group']) + parser.add_argument('--uniform_over_groups', type=parse_bool, const=True, nargs='?') + parser.add_argument('--distinct_groups', type=parse_bool, const=True, nargs='?') + parser.add_argument('--n_groups_per_batch', type=int) + parser.add_argument('--batch_size', type=int) + parser.add_argument('--eval_loader', choices=['standard'], default='standard') + + # Model + parser.add_argument('--model', choices=supported.models) + parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={}, + help='keyword arguments for model initialization passed as key1=value1 key2=value2') + + # Transforms + parser.add_argument('--train_transform', choices=supported.transforms) + parser.add_argument('--eval_transform', choices=supported.transforms) + parser.add_argument('--target_resolution', nargs='+', type=int, help='target resolution. for example --target_resolution 224 224 for standard resnet.') + parser.add_argument('--resize_scale', type=float) + parser.add_argument('--max_token_length', type=int) + + # Objective + parser.add_argument('--loss_function', choices = supported.losses) + + # Algorithm + parser.add_argument('--groupby_fields', nargs='+') + parser.add_argument('--group_dro_step_size', type=float) + parser.add_argument('--coral_penalty_weight', type=float) + parser.add_argument('--irm_lambda', type=float) + parser.add_argument('--irm_penalty_anneal_iters', type=int) + parser.add_argument('--algo_log_metric') + + # Model selection + parser.add_argument('--val_metric') + parser.add_argument('--val_metric_decreasing', type=parse_bool, const=True, nargs='?') + + # Optimization + parser.add_argument('--n_epochs', type=int) + parser.add_argument('--optimizer', choices=supported.optimizers) + parser.add_argument('--lr', type=float) + parser.add_argument('--weight_decay', type=float) + parser.add_argument('--max_grad_norm', type=float) + parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={}) + + # Scheduler + parser.add_argument('--scheduler', choices=supported.schedulers) + parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={}) + parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val') + parser.add_argument('--scheduler_metric_name') + + # Evaluation + parser.add_argument('--evaluate_all_splits', type=parse_bool, const=True, nargs='?', default=True) + parser.add_argument('--eval_splits', nargs='+', default=[]) + parser.add_argument('--eval_only', type=parse_bool, const=True, nargs='?', default=False) + parser.add_argument('--eval_epoch', default=None, type=int) + + # Misc + parser.add_argument('--device', type=int, default=0) + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--log_dir', default='./logs') + parser.add_argument('--log_every', default=50, type=int) + parser.add_argument('--save_step', type=int) + parser.add_argument('--save_best', type=parse_bool, const=True, nargs='?', default=True) + parser.add_argument('--save_last', type=parse_bool, const=True, nargs='?', default=True) + parser.add_argument('--no_group_logging', type=parse_bool, const=True, nargs='?') + parser.add_argument('--use_wandb', type=parse_bool, const=True, nargs='?', default=False) + parser.add_argument('--progress_bar', type=parse_bool, const=True, nargs='?', default=False) + parser.add_argument('--resume', type=parse_bool, const=True, nargs='?', default=False) + + config = parser.parse_args() + config = populate_defaults(config) + + # set device + config.device = torch.device("cuda:" + str(config.device)) if torch.cuda.is_available() else torch.device("cpu") + + ## Initialize logs + if os.path.exists(config.log_dir) and config.resume: + resume=True + mode='a' + elif os.path.exists(config.log_dir) and config.eval_only: + resume=False + mode='a' + else: + resume=False + mode='w' + + if not os.path.exists(config.log_dir): + os.makedirs(config.log_dir) + logger = Logger(os.path.join(config.log_dir, 'log.txt'), mode) + + # Record config + log_config(config, logger) + + # Set random seed + set_seed(config.seed) + + # Data + full_dataset = supported.datasets[config.dataset]( + root_dir=config.root_dir, + download=config.download, + split_scheme=config.split_scheme, + **config.dataset_kwargs) + + # To implement data augmentation (i.e., have different transforms + # at training time vs. test time), modify these two lines: + train_transform = initialize_transform( + transform_name=config.train_transform, + config=config, + dataset=full_dataset) + eval_transform = initialize_transform( + transform_name=config.eval_transform, + config=config, + dataset=full_dataset) + + train_grouper = CombinatorialGrouper( + dataset=full_dataset, + groupby_fields=config.groupby_fields) + + datasets = defaultdict(dict) + for split in full_dataset.split_dict.keys(): + if split=='train': + transform = train_transform + verbose = True + elif split == 'val': + transform = eval_transform + verbose = True + else: + transform = eval_transform + verbose = False + # Get subset + datasets[split]['dataset'] = full_dataset.get_subset( + split, + frac=config.frac, + transform=transform) + + if split == 'train': + datasets[split]['loader'] = get_train_loader( + loader=config.train_loader, + dataset=datasets[split]['dataset'], + batch_size=config.batch_size, + uniform_over_groups=config.uniform_over_groups, + grouper=train_grouper, + distinct_groups=config.distinct_groups, + n_groups_per_batch=config.n_groups_per_batch, + **config.loader_kwargs) + else: + datasets[split]['loader'] = get_eval_loader( + loader=config.eval_loader, + dataset=datasets[split]['dataset'], + grouper=train_grouper, + batch_size=config.batch_size, + **config.loader_kwargs) + + # Set fields + datasets[split]['split'] = split + datasets[split]['name'] = full_dataset.split_names[split] + datasets[split]['verbose'] = verbose + # Loggers + # Loggers + datasets[split]['eval_logger'] = BatchLogger( + os.path.join(config.log_dir, f'{split}_eval.csv'), mode=mode, use_wandb=(config.use_wandb and verbose)) + datasets[split]['algo_logger'] = BatchLogger( + os.path.join(config.log_dir, f'{split}_algo.csv'), mode=mode, use_wandb=(config.use_wandb and verbose)) + + if config.use_wandb: + initialize_wandb(config) + + # Logging dataset info + if config.no_group_logging and full_dataset.is_classification and full_dataset.y_size==1: + log_grouper = CombinatorialGrouper( + dataset=full_dataset, + groupby_fields=['y']) + elif config.no_group_logging: + log_grouper = None + else: + log_grouper = train_grouper + log_group_data(datasets, log_grouper, logger) + + ## Initialize algorithm + algorithm = initialize_algorithm( + config=config, + datasets=datasets, + train_grouper=train_grouper) + + if not config.eval_only: + ## Load saved results if resuming + resume_success = False + if resume: + save_path = os.path.join(config.log_dir, 'last_model.pth') + if not os.path.exists(save_path): + epochs = [ + int(file.split('_')[0]) + for file in os.listdir(config.log_dir) if file.endswith('.pth')] + if len(epochs) > 0: + latest_epoch = max(epochs) + save_path = os.path.join(config.log_dir, f'{latest_epoch}_model.pth') + try: + prev_epoch, best_val_metric = load(algorithm, save_path) + epoch_offset = prev_epoch + 1 + logger.write(f'Resuming from epoch {epoch_offset} with best val metric {best_val_metric}') + resume_success = True + except FileNotFoundError: + pass + + if resume_success == False: + epoch_offset=0 + best_val_metric=None + + + train( + algorithm=algorithm, + datasets=datasets, + general_logger=logger, + config=config, + epoch_offset=epoch_offset, + best_val_metric=best_val_metric) + else: + if config.eval_epoch is None: + eval_model_path = os.path.join(config.log_dir, 'best_model.pth') + else: + eval_model_path = os.path.join(config.log_dir, f'{config.eval_epoch}_model.pth') + best_epoch, best_val_metric = load(algorithm, eval_model_path) + if config.eval_epoch is None: + epoch = best_epoch + else: + epoch = config.eval_epoch + evaluate( + algorithm=algorithm, + datasets=datasets, + epoch=epoch, + general_logger=logger, + config=config) + + logger.close() + for split in datasets: + datasets[split]['eval_logger'].close() + datasets[split]['algo_logger'].close() + +if __name__=='__main__': + main() diff --git a/sandbox_data.ipynb b/sandbox_data.ipynb index 55a67da4..ad5ae4bd 100644 --- a/sandbox_data.ipynb +++ b/sandbox_data.ipynb @@ -1,5 +1,29 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- examples\n", + " - run_expt.py\n", + " - configs\n", + " - [x] supported.py\n", + " - [ ] model.py\n", + " - [ ] datasets.py\n", + " - models\n", + " - [x] CNN_genome.py\n", + " - train.py\n", + " - utils.py\n", + "- wilds\n", + " - [x] datasets/encodetfbs_dataset.py\n", + " - common\n", + " - metrics\n", + " - [ ] all_metrics.py\n", + " - data_loaders.py\n", + " - grouper.py\n", + " - [ ] utils.py ( threshold_at_recall() )" + ] + }, { "cell_type": "markdown", "metadata": {}, From 45491f23d11f31128bc5e147cb061c5cd394b53e Mon Sep 17 00:00:00 2001 From: aikanor Date: Mon, 8 Feb 2021 16:45:18 -0800 Subject: [PATCH 006/244] integration 2/ --- examples/configs/model.py | 3 +- sandbox_data.ipynb | 387 +-------------------------- wilds/datasets/camelyon17_dataset.py | 138 ++++++++++ wilds/datasets/encodetfbs_dataset.py | 5 +- wilds/version.py | 53 ++-- 5 files changed, 175 insertions(+), 411 deletions(-) diff --git a/examples/configs/model.py b/examples/configs/model.py index 12a429a7..31539d03 100644 --- a/examples/configs/model.py +++ b/examples/configs/model.py @@ -26,5 +26,6 @@ 'resnet18_ms': { 'target_resolution': (224, 224), }, - 'logistic_regression': {}, + 'logistic_regression': {}, + 'beagle': {}, } diff --git a/sandbox_data.ipynb b/sandbox_data.ipynb index ad5ae4bd..0a9806a6 100644 --- a/sandbox_data.ipynb +++ b/sandbox_data.ipynb @@ -24,6 +24,13 @@ " - [ ] utils.py ( threshold_at_recall() )" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "markdown", "metadata": {}, @@ -579,386 +586,6 @@ } ], "source": [] - }, - { - "cell_type": "code", - "execution_count": 165, - "metadata": {}, - "outputs": [ - { - "ename": "ModuleNotFoundError", - "evalue": "No module named 'torch_scatter'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minsert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'..'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata_loaders\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_train_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mget_eval_loader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 17\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrouper\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCombinatorialGrouper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_counts\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/dr_benchmark/wilds/common/data_loaders.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msampler\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mWeightedRandomSampler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSubsetRandomSampler\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_counts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msplit_into_groups\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/dr_benchmark/wilds/common/utils.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch_scatter\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mSubset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mpandas\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtypes\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCategoricalDtype\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torch_scatter'" - ] - } - ], - "source": [ - "import os, csv\n", - "import time\n", - "import argparse\n", - "import IPython\n", - "import pandas as pd\n", - "import torch\n", - "import torch.nn as nn\n", - "import torchvision\n", - "import sys\n", - "from collections import defaultdict\n", - "# torch.multiprocessing.set_sharing_strategy('file_system')\n", - "\n", - "# TODO: Replace this once we make wilds into an installed package\n", - "sys.path.insert(1, os.path.join(sys.path[0], '..'))\n", - "\n", - "from wilds.common.data_loaders import get_train_loader, get_eval_loader\n", - "from wilds.common.grouper import CombinatorialGrouper\n", - "from wilds.common.utils import get_counts\n", - "\n", - "from models.model_attributes import model_attributes\n", - "from utils import set_seed, Logger, BatchLogger, log_args, ParseKwargs, load\n", - "from train import train, evaluate\n", - "from data import dataset_attributes\n", - "from optimizer import optimizer_attributes\n", - "from scheduler import scheduler_attributes\n", - "from loss import losses\n", - "from utils import log_group_data\n", - "from algorithms.constructors import algorithm_constructors" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "from examples.models.model_attributes import model_attributes" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "ename": "ModuleNotFoundError", - "evalue": "No module named 'utils'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_attributes\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmodel_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mset_seed\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mLogger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mCSVBatchLogger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mParseKwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mload\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 22\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdataset_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0moptimizer_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/dr_benchmark/examples/train.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0msave\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'utils'" - ] - } - ], - "source": [ - "def initialize_algorithm(args, datasets, train_grouper):\n", - " train_dataset = datasets['train']['dataset']\n", - " train_loader = datasets['train']['loader']\n", - "\n", - " # Configure the final layer of the networks used\n", - " # The code below are defaults. Edit this if you need special config for your model.\n", - " if (train_dataset.is_classification) and (train_dataset.y_size == 1):\n", - " # For single-task classification, we have one output per class\n", - " d_out = train_dataset.n_classes\n", - " elif (train_dataset.is_classification) and (train_dataset.y_size > 1) and (train_dataset.n_classes == 2):\n", - " # For multi-task binary classification (each output is the logit for each binary class)\n", - " d_out = train_dataset.y_size\n", - " elif (not train_dataset.is_classification):\n", - " # For regression, we have one output per target dimension\n", - " d_out = train_dataset.y_size\n", - " else:\n", - " raise RuntimeError('d_out not defined.')\n", - " \n", - "\n", - " # Sanity checking input args\n", - " if args.algorithm == 'groupDRO':\n", - " assert args.train_loader_kwargs['uniform_over_groups']\n", - " elif args.algorithm in ['deepCORAL', 'IRM']:\n", - " assert args.train_loader == 'group'\n", - " assert args.train_loader_kwargs['uniform_over_groups']\n", - " assert args.train_loader_kwargs['distinct_groups']\n", - "\n", - " # Other config\n", - " n_train_steps = len(train_loader) * args.n_epochs\n", - "# prediction_fn = dataset_attributes[args.dataset]['prediction_fn']\n", - " loss = losses[args.loss_function]\n", - " metric = dataset_attributes[args.dataset]['metric']\n", - " train_g = train_grouper.metadata_to_group(train_dataset.metadata_array)\n", - " is_group_in_train = get_counts(train_g, train_grouper.n_groups) > 0\n", - " algorithm_constructor = algorithm_constructors[args.algorithm]\n", - " algorithm = algorithm_constructor(\n", - " args=args,\n", - " d_out=d_out,\n", - " grouper=train_grouper,\n", - " loss=loss,\n", - " metric=metric,\n", - " n_train_steps=n_train_steps,\n", - " is_group_in_train=is_group_in_train)\n", - " return algorithm" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def main():\n", - " parser = argparse.ArgumentParser()\n", - "\n", - " # Dataset\n", - " parser.add_argument('-d', '--dataset', choices=dataset_attributes.keys(), required=True)\n", - " parser.add_argument('--split_scheme', default='standard',\n", - " help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')\n", - " parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={})\n", - " parser.add_argument('--root_dir', default=None, required=True,\n", - " help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')\n", - " parser.add_argument('--download', default=False, action='store_true',\n", - " help='If true, tries to downloads the dataset if it does not exist in root_dir.')\n", - " parser.add_argument('--frac', type=float, default=1.0,\n", - " help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.')\n", - "\n", - " # Loaders\n", - " parser.add_argument('--train_loader', choices=['standard', 'group'], default='standard')\n", - " parser.add_argument('--train_loader_kwargs', nargs='*', action=ParseKwargs, default={})\n", - " parser.add_argument('--eval_loader', choices=['standard', 'group'], default='standard')\n", - " parser.add_argument('--batch_size', type=int, default=32)\n", - " parser.add_argument('--no_pin_memory', action='store_true') # TODO: put as loader_kwargs\n", - " parser.add_argument('--num_workers', type=int, default=4) # TODO: put as loader kwargs\n", - "\n", - " # Model\n", - " parser.add_argument(\n", - " '--model',\n", - " choices=model_attributes.keys(),\n", - " default='resnet50')\n", - " parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},\n", - " help='keyword arguments for model initialization passed as key1=value1 key2=value2')\n", - " parser.add_argument('--train_from_scratch', action='store_true', default=False)\n", - "\n", - " # Algorithm and objective\n", - " parser.add_argument('--algorithm', required=True, choices=algorithm_constructors.keys())\n", - " parser.add_argument('--algorithm_kwargs', nargs='*', action=ParseKwargs, default={})\n", - " parser.add_argument('--groupby_fields', nargs='+', default=None)\n", - " parser.add_argument('--loss_function', required=True, choices = losses.keys()) #TODO: make default\n", - " parser.add_argument('--val_metric', default=None)\n", - "\n", - " # Optimization\n", - " parser.add_argument('--n_epochs', type=int, default=4)\n", - " parser.add_argument('--optimizer', default=None, choices=optimizer_attributes.keys())\n", - " parser.add_argument('--lr', type=float, required=True)\n", - " parser.add_argument('--weight_decay', type=float, required=True)\n", - " parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})\n", - " parser.add_argument('--scheduler', default=None, choices=scheduler_attributes.keys())\n", - " parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})\n", - " parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')\n", - " parser.add_argument('--scheduler_metric_name')\n", - "\n", - " # Evaluation\n", - " parser.add_argument('--evaluate_all_splits', action='store_true', default=False)\n", - " parser.add_argument('--additional_eval_splits', nargs='+', default=[])\n", - "\n", - " # Misc\n", - " parser.add_argument('--device', type=int, default=0)\n", - " parser.add_argument('--seed', type=int, default=0)\n", - " parser.add_argument('--log_dir', default='./logs')\n", - " parser.add_argument('--log_every', default=50, type=int)\n", - " parser.add_argument('--save_step', type=int, default=None)\n", - " parser.add_argument('--save_best', action='store_true', default=False)\n", - " parser.add_argument('--save_last', action='store_true', default=False)\n", - " parser.add_argument('--save_outputs', action='store_true', default=False)\n", - " parser.add_argument('--no_group_logging', action='store_true', default=False)\n", - " parser.add_argument('--val_metric_decreasing', action='store_true', default=False)\n", - " parser.add_argument('--use_wandb', action='store_true', default=False)\n", - " parser.add_argument('--progress_bar', action='store_true', default=False)\n", - " parser.add_argument('--resume', default=False, action='store_true')\n", - " parser.add_argument('--eval_only', default=False, action='store_true')\n", - "\n", - " args = parser.parse_args()\n", - "\n", - " # set device\n", - " args.device = torch.device(\"cuda:\" + str(args.device)) if torch.cuda.is_available() else torch.device(\"cpu\")\n", - "\n", - " # Set defaults\n", - " if args.groupby_fields is None:\n", - " args.no_group_logging = True\n", - " if args.val_metric is None:\n", - " args.val_metric = dataset_attributes[args.dataset]['val_metric']\n", - "\n", - " ## Initialize logs\n", - " if os.path.exists(args.log_dir) and args.resume:\n", - " resume=True\n", - " mode='a'\n", - " else:\n", - " resume=False\n", - " mode='w'\n", - " if not os.path.exists(args.log_dir):\n", - " os.makedirs(args.log_dir)\n", - " logger = Logger(os.path.join(args.log_dir, 'log.txt'), mode)\n", - "\n", - " # Record args\n", - " log_args(args, logger)\n", - "\n", - " # Set random seed\n", - " set_seed(args.seed)\n", - "\n", - " # Data\n", - " full_dataset = dataset_attributes[args.dataset]['constructor'](\n", - " root_dir=args.root_dir,\n", - " download=args.download,\n", - " split_scheme=args.split_scheme,\n", - " **args.dataset_kwargs)\n", - "\n", - " # To implement data augmentation (i.e., have different transforms\n", - " # at training time vs. test time), modify these two lines:\n", - " train_transform = dataset_attributes[args.dataset]['transform'](args.model)\n", - " if dataset_attributes[args.dataset].get('eval_transform') is None:\n", - " eval_transform = dataset_attributes[args.dataset]['transform'](args.model)\n", - " else:\n", - " eval_transform = dataset_attributes[args.dataset]['eval_transform'](args.model)\n", - "\n", - " train_grouper = CombinatorialGrouper(\n", - " dataset=full_dataset,\n", - " groupby_fields=args.groupby_fields)\n", - "\n", - " datasets = defaultdict(dict)\n", - " for split in full_dataset.split_dict.keys():\n", - " if split=='train':\n", - " transform = train_transform\n", - " verbose = True\n", - " elif split == 'val':\n", - " transform = eval_transform\n", - " verbose = True\n", - " else:\n", - " transform = eval_transform\n", - " verbose = False\n", - " # Get subset\n", - " datasets[split]['dataset'] = full_dataset.get_subset(\n", - " split,\n", - " frac=args.frac,\n", - " transform=transform)\n", - "\n", - " # Get loader\n", - " shared_loader_kwargs = {\n", - " 'num_workers': args.num_workers,\n", - " 'pin_memory': not args.no_pin_memory,\n", - " 'batch_size': args.batch_size,\n", - " 'collate_fn': dataset_attributes[args.dataset]['collate']\n", - " }\n", - "\n", - " if split == 'train':\n", - " datasets[split]['loader'] = get_train_loader(\n", - " loader=args.train_loader,\n", - " dataset=datasets[split]['dataset'],\n", - " grouper=train_grouper,\n", - " train_loader_kwargs=args.train_loader_kwargs,\n", - " **shared_loader_kwargs)\n", - " else:\n", - " datasets[split]['loader'] = get_eval_loader(\n", - " loader=args.eval_loader,\n", - " dataset=datasets[split]['dataset'],\n", - " grouper=train_grouper,\n", - " **shared_loader_kwargs)\n", - "\n", - " # Set fields\n", - " datasets[split]['split'] = split\n", - " datasets[split]['name'] = full_dataset.split_names[split]\n", - " datasets[split]['verbose'] = verbose\n", - " # Loggers\n", - " # Loggers\n", - " datasets[split]['eval_logger'] = BatchLogger(\n", - " os.path.join(args.log_dir, f'{split}_eval.csv'), mode=mode, use_wandb=args.use_wandb)\n", - " datasets[split]['algo_logger'] = BatchLogger(\n", - " os.path.join(args.log_dir, f'{split}_algo.csv'), mode=mode, use_wandb=args.use_wandb)\n", - "\n", - " if args.use_wandb:\n", - " initialize_wandb(args)\n", - "\n", - " # Logging dataset info\n", - " if args.no_group_logging and full_dataset.is_classification and full_dataset.y_size==1:\n", - " log_grouper = CombinatorialGrouper(\n", - " dataset=full_dataset,\n", - " groupby_fields=['y'])\n", - " elif args.no_group_logging:\n", - " log_grouper = None\n", - " else:\n", - " log_grouper = train_grouper\n", - " log_group_data(args, datasets, log_grouper, logger)\n", - "\n", - " ## Initialize algorithm\n", - " algorithm = initialize_algorithm(args, datasets, train_grouper)\n", - "\n", - " if not args.eval_only:\n", - " ## Load saved results if resuming\n", - " resume_success = False\n", - " if resume:\n", - " save_path = os.path.join(args.log_dir, 'last_model.pth')\n", - " if not os.path.exists(save_path):\n", - " epochs = [\n", - " int(file.split('_')[0])\n", - " for file in os.listdir(args.log_dir) if file.endswith('.pth')]\n", - " if len(epochs) > 0:\n", - " latest_epoch = max(epochs)\n", - " save_path = os.path.join(args.log_dir, f'{latest_epoch}_model.pth')\n", - " try:\n", - " prev_epoch, best_val_metric = load(algorithm, save_path)\n", - " epoch_offset = prev_epoch + 1\n", - " logger.write(f'Resuming from epoch {epoch_offset} with best val metric {best_val_metric}')\n", - " resume_success = True\n", - " except FileNotFoundError:\n", - " pass\n", - "\n", - " if resume_success == False:\n", - " epoch_offset=0\n", - " best_val_metric=None\n", - "\n", - "\n", - " train(algorithm,\n", - " datasets,\n", - " logger,\n", - " args,\n", - " epoch_offset=epoch_offset,\n", - " best_val_metric=best_val_metric)\n", - " else:\n", - " best_model_path = os.path.join(args.log_dir, 'best_model.pth')\n", - " best_epoch, best_val_metric = load(algorithm, best_model_path)\n", - " evaluate(algorithm, datasets, best_epoch, logger)\n", - "\n", - " logger.close()\n", - " for split in datasets:\n", - " datasets[split]['eval_logger'].close()\n", - " datasets[split]['algo_logger'].close()\n", - "\n", - "if __name__=='__main__':\n", - " main()\n" - ] } ], "metadata": { diff --git a/wilds/datasets/camelyon17_dataset.py b/wilds/datasets/camelyon17_dataset.py index d170ddf2..691eac5d 100644 --- a/wilds/datasets/camelyon17_dataset.py +++ b/wilds/datasets/camelyon17_dataset.py @@ -139,3 +139,141 @@ def eval(self, y_pred, y_true, metadata): self._metric, self._eval_grouper, y_pred, y_true, metadata) + + +class EncodeTFBSDataset(WILDSDataset): + """ + EncodeTFBS dataset + Website: https://www.synapse.org/#!Synapse:syn6131484 + """ + + def __init__(self, root_dir, download, split_scheme): + self._dataset_name = 'encodeTFBS' + self._download_url = 'https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/' + self._data_dir = self.initialize_data_dir(root_dir, download) + self._y_size = 1 + self._n_classes = 2 + + self._tr_chrs = ['chr2', 'chr9', 'chr11'] + self._te_chrs = ['chr1', 'chr8', 'chr21'] + self._transcription_factor = 'MAX' + self._train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] + self._val_celltype = ['A549'] + self._test_celltype = ['GM12878'] + self._all_celltypes = self._train_celltypes + self._val_celltype + self._test_celltype + + self._metadata_fields = ['chr', 'celltype', 'y'] + self._metadata_map = {} + self._metadata_map['chr'] = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] + self._metadata_map['celltype'] = self._all_celltypes + + # Load sequence and DNase features + sequence_filename = os.path.join(self._data_dir, 'sequence.npz') + seq_arr = np.load(sequence_filename) + self._seq_bp = {} + for chrom in seq_arr: + self._seq_bp[chrom] = seq_arr[chrom] + + self._dnase_allcelltypes = {} + for ct in self._all_celltypes: + dnase_filename = os.path.join(self._data_dir, '{}_dnase.npz'.format(ct)) + dnase_npz_file = np.load(dnase_filename) + self._dnase_allcelltypes[ct] = {} + for chrom in seq_bp: + self._dnase_allcelltypes[ct][chrom] = dnase_npz_file[chrom] + + # Read in metadata dataframe from training+validation data + train_chr = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.train.labels.tsv.gz'.format(self._transcription_factor)), sep='\t') + val_chr = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.val.labels.tsv.gz'.format(self._transcription_factor)), sep='\t') + training_df = train_chr[np.isin(train_chr['chr'], self._tr_chrs)] + val_df = val_chr[np.isin(val_chr['chr'], self._te_chrs)] + all_df = pd.concat([training_df, val_df]) + + # Filter by start/stop coordinate if needed + filter_msk = all_df['start'] >= 0 + filter_msk = all_df['start']%1000 == 0 + all_df = all_df[filter_msk] + + pd_list = [] + for ct in self._train_celltypes: + tc_chr = all_df[['chr', 'start', 'stop', ct]] + tc_chr.columns = ['chr', 'start', 'stop', 'y'] + tc_chr['celltype'] = ct + pd_list.append(tc_chr) + metadata_df = pd.concat(pd_list) + + # Get the y values, and remove ambiguous labels by default. + y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values + non_ambig_mask = (y_array != -1) + metadata_df['y'] = y_array + self._metadata_df = metadata_df[non_ambig_mask] + self._y_array = torch.LongTensor(y_array[non_ambig_mask]) + + chr_ints = self._metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['chr'])] )).values + celltype_ints = self._metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['celltype'])] )).values + self._metadata_array = torch.stack( + (torch.LongTensor(chr_ints), + torch.LongTensor(celltype_ints), + self._y_array), + dim=1) + + # Get the splits + # TODO Extract splits as encoded in split_scheme. Hardcoded here for now. + self._split_scheme = split_scheme + self._split_dict = { + 'train': 0, + 'val-id': 1, + 'test': 2, + 'val-ood': 3 + } + self._split_names = { + 'train': 'Train', + 'val-id': 'Validation (ID)', + 'test': 'Test', + 'val-ood': 'Validation (OOD)', + } + train_chr_mask = np.isin(self._metadata_df['chr'], self._tr_chrs) + val_chr_mask = np.isin(self._metadata_df['chr'], self._te_chrs) + train_celltype_mask = np.isin(self._metadata_df['celltype'], self._train_celltypes) + val_celltype_mask = np.isin(self._metadata_df['celltype'], self._val_celltype) + test_celltype_mask = np.isin(self._metadata_df['celltype'], self._test_celltype) + + split_array = -1*np.ones(self._metadata_df.shape[0]).astype(int) + split_array[np.logical_and(train_chr_mask, train_celltype_mask)] = self._split_dict['train'] + split_array[np.logical_and(val_chr_mask, test_celltype_mask)] = self._split_dict['test'] + # Validate using test chr, either using a designated validation cell line ('val-ood') or a training cell line ('val-id') + split_array[np.logical_and(val_chr_mask, val_celltype_mask)] = self._split_dict['val-ood'] + split_array[np.logical_and(val_chr_mask, train_celltype_mask)] = self._split_dict['val-id'] + if self._split_scheme=='standard': + self._metadata_df['split'] = split_array + self._split_array = split_array + else: + raise ValueError(f'Split scheme {self._split_scheme} not recognized') + self._eval_grouper = CombinatorialGrouper( + dataset=self, + groupby_fields=['celltype']) + self._metric = Auprc() + + super().__init__(root_dir, download, split_scheme) + + def get_input(self, idx): + """ + Returns x for a given idx. + Computes this from: + (1) sequence features in self._seq_bp + (2) DNase features in self._dnase_allcelltypes + (3) Metadata for the index (location along the genome with 1kb window width) + """ + this_metadata = self._metadata_df.iloc[idx, :] + flank_size = 400 + interval_start = this_metadata['start'] - flank_size + interval_end = this_metadata['stop'] + flank_size + dnase_this = _dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end] + seq_this = _seq_bp[this_metadata['chr']][interval_start:interval_end] + return np.column_stack([seq_this, dnase_this]) + + def eval(self, y_pred, y_true, metadata): + return self.standard_group_eval( + self._metric, + self._eval_grouper, + y_pred, y_true, metadata) diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 08276aa9..6996cc15 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -5,7 +5,6 @@ from wilds.datasets.wilds_dataset import WILDSDataset from wilds.common.grouper import CombinatorialGrouper from wilds.common.metrics.eval_metric import Accuracy -from wilds.common.eval import standard_group_eval import IPython @@ -133,7 +132,7 @@ def get_input(self, idx): (3) Metadata for the index (location along the genome with 1kb window width) """ this_metadata = self._metadata_df.iloc[idx, :] - flank_size = 500 + flank_size = 400 interval_start = this_metadata['start'] - flank_size interval_end = this_metadata['stop'] + flank_size dnase_this = _dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end] @@ -141,7 +140,7 @@ def get_input(self, idx): return np.column_stack([seq_this, dnase_this]) def eval(self, y_pred, y_true, metadata): - return standard_group_eval( + return self.standard_group_eval( self._metric, self._eval_grouper, y_pred, y_true, metadata) diff --git a/wilds/version.py b/wilds/version.py index 3f7bf4a6..a35ec15c 100644 --- a/wilds/version.py +++ b/wilds/version.py @@ -1,27 +1,26 @@ -# Adapted from https://github.com/snap-stanford/ogb/blob/master/ogb/version.py - -import os -import logging -from threading import Thread - -__version__ = '1.0.0' - -try: - os.environ['OUTDATED_IGNORE'] = '1' - from outdated import check_outdated # noqa -except ImportError: - check_outdated = None - -def check(): - try: - is_outdated, latest = check_outdated('wilds', __version__) - if is_outdated: - logging.warning( - f'The WILDS package is out of date. Your version is ' - f'{__version__}, while the latest version is {latest}.') - except Exception: - pass - -if check_outdated is not None: - thread = Thread(target=check) - thread.start() +# Adapted from https://github.com/snap-stanford/ogb/blob/master/ogb/version.py + +import os +import logging +from threading import Thread + +__version__ = '1.0.0' + +try: + os.environ['OUTDATED_IGNORE'] = '1' + from outdated import check_outdated # noqa +except ImportError: + check_outdated = None + +def check(): + try: + is_outdated, latest = check_outdated('wilds', __version__) + if is_outdated: + logging.warning( + f'The WILDS package is out of date. Your version is {__version__}, while the latest version is {latest}.') + except Exception: + pass + +if check_outdated is not None: + thread = Thread(target=check) + thread.start() From dd03d7d583d240ae92558ce1ec2b9d356cde02c2 Mon Sep 17 00:00:00 2001 From: aikanor Date: Mon, 8 Feb 2021 17:40:41 -0800 Subject: [PATCH 007/244] integration 3/ --- sbox_run_expt.ipynb | 1032 ++++++++++++++++++++++++++ wilds/datasets/camelyon17_dataset.py | 138 ---- wilds/datasets/encodetfbs_dataset.py | 4 +- 3 files changed, 1033 insertions(+), 141 deletions(-) create mode 100644 sbox_run_expt.ipynb diff --git a/sbox_run_expt.ipynb b/sbox_run_expt.ipynb new file mode 100644 index 00000000..612397ce --- /dev/null +++ b/sbox_run_expt.ipynb @@ -0,0 +1,1032 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# run_expt.py contents" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "ename": "SyntaxError", + "evalue": "invalid syntax (version.py, line 20)", + "output_type": "error", + "traceback": [ + "\u001b[0;36m File \u001b[0;32m\"wilds/version.py\"\u001b[0;36m, line \u001b[0;32m20\u001b[0m\n\u001b[0;31m f'The WILDS package is out of date. Your version is {__version__}, while the latest version is {latest}.')\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax\n" + ] + } + ], + "source": [ + "import os, csv\n", + "import time\n", + "import argparse\n", + "import pandas as pd\n", + "import torch\n", + "import torch.nn as nn\n", + "import torchvision\n", + "import sys\n", + "from collections import defaultdict\n", + "\n", + "from wilds.common.data_loaders import get_train_loader, get_eval_loader\n", + "from wilds.common.grouper import CombinatorialGrouper\n", + "\n", + "from utils import set_seed, Logger, BatchLogger, log_config, ParseKwargs, load, initialize_wandb, log_group_data, parse_bool\n", + "from train import train, evaluate\n", + "from algorithms.initializer import initialize_algorithm\n", + "from transforms import initialize_transform\n", + "from configs.utils import populate_defaults\n", + "import configs.supported as supported" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Initialize dataset object" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "57.8772239685\n", + "66.8270189762\n" + ] + } + ], + "source": [ + "import numpy as np, pandas as pd, os, time, torch, torchvision\n", + "data_dir = '/oak/stanford/groups/akundaje/abalsubr/DREAM/wilds/codalab_archive/'\n", + "tf = 'MAX'\n", + "itime = time.time()\n", + "train_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.train.labels.tsv.gz'.format(tf)), sep='\\t')\n", + "print(time.time() - itime)\n", + "val_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.val.labels.tsv.gz'.format(tf)), sep='\\t')\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", + "val_celltype = ['A549']\n", + "test_celltype = ['GM12878']\n", + "all_celltypes = train_celltypes + val_celltype + test_celltype\n", + "\n", + "metadata_map = {}\n", + "metadata_map['chr'] = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']\n", + "metadata_map['celltype'] = all_celltypes\n", + "\n", + "_split_dict = {\n", + " 'train': 0,\n", + " 'val-id': 1,\n", + " 'test': 2,\n", + " 'val-ood': 3\n", + "}\n", + "_split_names = {\n", + " 'train': 'Train',\n", + " 'val-id': 'Validation (ID)',\n", + " 'test': 'Test',\n", + " 'val-ood': 'Validation (OOD)'\n", + "}\n", + "_split_scheme = 'standard'" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "('H1-hESC', 25.299736976623535)\n", + "('HCT116', 49.68733310699463)\n", + "('HeLa-S3', 74.65905213356018)\n", + "('HepG2', 99.33112812042236)\n", + "('K562', 124.1327919960022)\n", + "('A549', 149.19999814033508)\n", + "('GM12878', 174.0277030467987)\n" + ] + } + ], + "source": [ + "itime = time.time()\n", + "sequence_filename = os.path.join(data_dir, 'sequence.npz')\n", + "seq_arr = np.load(sequence_filename)\n", + "print(time.time() - itime)\n", + "\n", + "itime = time.time()\n", + "_seq_bp = {}\n", + "for chrom in seq_arr:\n", + " _seq_bp[chrom] = seq_arr[chrom]\n", + " print(chrom, time.time() - itime)\n", + "itime = time.time()\n", + "_dnase_allcelltypes = {}\n", + "for ct in all_celltypes:\n", + " dnase_filename = os.path.join(data_dir, '{}_dnase.npz'.format(ct))\n", + " dnase_npz_file = np.load(dnase_filename)\n", + " _dnase_allcelltypes[ct] = {}\n", + " for chrom in _seq_bp:\n", + " _dnase_allcelltypes[ct][chrom] = dnase_npz_file[chrom]\n", + " print(ct, time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class Beagle2(nn.Module):\n", + " \"\"\"\n", + " Neural net models over genomic sequence.\n", + " Input:\n", + " - sequence_length: int (default 1000) \n", + " - Shape: (N, 5, sequence_length, 1) with batch size N.\n", + " \n", + " Output:\n", + " - prediction (Tensor): float torch tensor of shape (N, )\n", + " \n", + " TODO: Finish docstring.\n", + " \"\"\"\n", + " def __init__(self):\n", + " \"\"\"\n", + " Parameters\n", + " ----------\n", + " sequence_length : int\n", + " n_genomic_features : int\n", + " \"\"\"\n", + " super(Beagle2, self).__init__()\n", + "\n", + " self.dropout = 0.3\n", + " self.num_cell_types = 1\n", + " self.conv1 = nn.Conv2d(5, 300, (19, 1), stride = (1, 1), padding=(9,0))\n", + " self.conv2 = nn.Conv2d(300, 200, (11, 1), stride = (1, 1), padding = (5,0))\n", + " self.conv3 = nn.Conv2d(200, 200, (7, 1), stride = (1, 1), padding = (4,0))\n", + " self.bn1 = nn.BatchNorm2d(300)\n", + " self.bn2 = nn.BatchNorm2d(200)\n", + " self.bn3 = nn.BatchNorm2d(200)\n", + " self.maxpool1 = nn.MaxPool2d((3, 1))\n", + " self.maxpool2 = nn.MaxPool2d((4, 1))\n", + " self.maxpool3 = nn.MaxPool2d((4, 1))\n", + "\n", + " self.fc1 = nn.Linear(4200, 1000)\n", + " self.bn4 = nn.BatchNorm1d(1000)\n", + "\n", + " self.fc2 = nn.Linear(1000, 1000)\n", + " self.bn5 = nn.BatchNorm1d(1000)\n", + "\n", + " self.fc3 = nn.Linear(1000, self.num_cell_types)\n", + "\n", + " def forward(self, s):\n", + " s = s.permute(0, 2, 1).contiguous() # batch_size x 4 x 1000\n", + " s = s.view(-1, 5, 1000, 1) # batch_size x 4 x 1000 x 1 [4 channels]\n", + " s = self.maxpool1(F.relu(self.bn1(self.conv1(s)))) # batch_size x 300 x 333 x 1\n", + " s = self.maxpool2(F.relu(self.bn2(self.conv2(s)))) # batch_size x 200 x 83 x 1\n", + " s = self.maxpool3(F.relu(self.bn3(self.conv3(s)))) # batch_size x 200 x 21 x 1\n", + " s = s.view(-1, 4200)\n", + " conv_out = s\n", + "\n", + " s = F.dropout(F.relu(self.bn4(self.fc1(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", + " #s = F.dropout(F.relu(self.bn5(self.fc2(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", + " \n", + " \n", + " s = self.fc3(s)\n", + "\n", + " return s, conv_out\n", + "\n", + "\n", + "class DanQ(nn.Module):\n", + " def __init__(self, sequence_length, n_genomic_features):\n", + " \"\"\"\n", + " Parameters\n", + " ----------\n", + " sequence_length : int\n", + " Input sequence length\n", + " n_genomic_features : int\n", + " Total number of features to predict\n", + " \"\"\"\n", + " super(DanQ, self).__init__()\n", + " self.nnet = nn.Sequential(\n", + " nn.Conv1d(4, 320, kernel_size=26),\n", + " nn.ReLU(inplace=True),\n", + " nn.MaxPool1d(\n", + " kernel_size=13, stride=13),\n", + " nn.Dropout(0.2))\n", + "\n", + " self.bdlstm = nn.Sequential(\n", + " nn.LSTM(\n", + " 320, 320, num_layers=1, batch_first=True, bidirectional=True))\n", + "\n", + " self._n_channels = math.floor(\n", + " (sequence_length - 25) / 13)\n", + " self.classifier = nn.Sequential(\n", + " nn.Dropout(0.5),\n", + " nn.Linear(self._n_channels * 640, 925),\n", + " nn.ReLU(inplace=True),\n", + " nn.Linear(925, n_genomic_features),\n", + " nn.Sigmoid())\n", + "\n", + " def forward(self, x):\n", + " \"\"\"Forward propagation of a batch.\n", + " \"\"\"\n", + " out = self.nnet(x)\n", + " reshape_out = out.transpose(0, 1).transpose(0, 2)\n", + " out, _ = self.bdlstm(reshape_out)\n", + " out = out.transpose(0, 1)\n", + " reshape_out = out.contiguous().view(\n", + " out.size(0), 640 * self._n_channels)\n", + " predict = self.classifier(reshape_out)\n", + " return predict\n", + "\n", + "\n", + "class DeepSEA(nn.Module):\n", + " def __init__(self, sequence_length, n_genomic_features):\n", + " \"\"\"\n", + " Parameters\n", + " ----------\n", + " sequence_length : int\n", + " n_genomic_features : int\n", + " \"\"\"\n", + " super(DeepSEA, self).__init__()\n", + " conv_kernel_size = 8\n", + " pool_kernel_size = 4\n", + "\n", + " self.conv_net = nn.Sequential(\n", + " nn.Conv1d(4, 320, kernel_size=conv_kernel_size),\n", + " nn.ReLU(inplace=True),\n", + " nn.MaxPool1d(\n", + " kernel_size=pool_kernel_size, stride=pool_kernel_size),\n", + " nn.Dropout(p=0.2),\n", + "\n", + " nn.Conv1d(320, 480, kernel_size=conv_kernel_size),\n", + " nn.ReLU(inplace=True),\n", + " nn.MaxPool1d(\n", + " kernel_size=pool_kernel_size, stride=pool_kernel_size),\n", + " nn.Dropout(p=0.2),\n", + "\n", + " nn.Conv1d(480, 960, kernel_size=conv_kernel_size),\n", + " nn.ReLU(inplace=True),\n", + " nn.Dropout(p=0.5))\n", + "\n", + " reduce_by = conv_kernel_size - 1\n", + " pool_kernel_size = float(pool_kernel_size)\n", + " self.n_channels = int(\n", + " np.floor(\n", + " (np.floor(\n", + " (sequence_length - reduce_by) / pool_kernel_size)\n", + " - reduce_by) / pool_kernel_size)\n", + " - reduce_by)\n", + " self.classifier = nn.Sequential(\n", + " nn.Linear(960 * self.n_channels, n_genomic_features),\n", + " nn.ReLU(inplace=True),\n", + " nn.Linear(n_genomic_features, n_genomic_features),\n", + " nn.Sigmoid())\n", + "\n", + " def forward(self, x):\n", + " \"\"\"Forward propagation of a batch.\n", + " \"\"\"\n", + " out = self.conv_net(x)\n", + " reshape_out = out.view(out.size(0), 960 * self.n_channels)\n", + " predict = self.classifier(reshape_out)\n", + " return predict" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "class Beagle(nn.Module):\n", + " \"\"\"\n", + " Neural net models over genomic sequence.\n", + " Input:\n", + " - sequence_length: int (default 1000) \n", + " - Shape: (N, 5, sequence_length, 1) with batch size N.\n", + " \n", + " Output:\n", + " - prediction (Tensor): float torch tensor of shape (N, )\n", + " \n", + " TODO: Finish docstring.\n", + " \"\"\"\n", + " def __init__(self):\n", + " \"\"\"\n", + " Parameters\n", + " ----------\n", + " sequence_length : int\n", + " n_genomic_features : int\n", + " \"\"\"\n", + " super(Beagle, self).__init__()\n", + "\n", + " self.dropout = 0.3\n", + " self.num_cell_types = 1\n", + " self.conv1 = nn.Conv2d(5, 300, (19, 1), stride = (1, 1), padding=(9,0))\n", + " self.conv2 = nn.Conv2d(300, 200, (11, 1), stride = (1, 1), padding = (5,0))\n", + " self.conv3 = nn.Conv2d(200, 200, (7, 1), stride = (1, 1), padding = (4,0))\n", + " self.bn1 = nn.BatchNorm2d(300)\n", + " self.bn2 = nn.BatchNorm2d(200)\n", + " self.bn3 = nn.BatchNorm2d(200)\n", + " self.maxpool1 = nn.MaxPool2d((3, 1))\n", + " self.maxpool2 = nn.MaxPool2d((4, 1))\n", + " self.maxpool3 = nn.MaxPool2d((4, 1))\n", + "\n", + " self.fc1 = nn.Linear(4200, 1000)\n", + " self.bn4 = nn.BatchNorm1d(1000)\n", + "\n", + " self.fc2 = nn.Linear(1000, 1000)\n", + " self.bn5 = nn.BatchNorm1d(1000)\n", + "\n", + " self.fc3 = nn.Linear(1000, self.num_cell_types)\n", + "\n", + " def forward(self, s):\n", + " s = s.permute(0, 2, 1).contiguous() # batch_size x 5 x 1000\n", + " s = s.view(-1, 5, 1000, 1) # batch_size x 5 x 1000 x 1 [5 channels]\n", + " s = self.maxpool1(F.relu(self.bn1(self.conv1(s)))) # batch_size x 300 x 333 x 1\n", + " s = self.maxpool2(F.relu(self.bn2(self.conv2(s)))) # batch_size x 200 x 83 x 1\n", + " s = self.maxpool3(F.relu(self.bn3(self.conv3(s)))) # batch_size x 200 x 21 x 1\n", + " s = s.view(-1, 4200)\n", + " conv_out = s\n", + "\n", + " s = F.dropout(F.relu(self.bn4(self.fc1(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", + " s = F.dropout(F.relu(self.bn5(self.fc2(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", + " \n", + " s = self.fc3(s)\n", + "\n", + " return s, conv_out" + ] + }, + { + "cell_type": "code", + "execution_count": 86, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[('nnet.0.weight', 33280),\n", + " ('nnet.0.bias', 320),\n", + " ('bdlstm.0.weight_ih_l0', 409600),\n", + " ('bdlstm.0.weight_hh_l0', 409600),\n", + " ('bdlstm.0.bias_ih_l0', 1280),\n", + " ('bdlstm.0.bias_hh_l0', 1280),\n", + " ('bdlstm.0.weight_ih_l0_reverse', 409600),\n", + " ('bdlstm.0.weight_hh_l0_reverse', 409600),\n", + " ('bdlstm.0.bias_ih_l0_reverse', 1280),\n", + " ('bdlstm.0.bias_hh_l0_reverse', 1280),\n", + " ('classifier.1.weight', 592000),\n", + " ('classifier.1.bias', 925),\n", + " ('classifier.3.weight', 4625),\n", + " ('classifier.3.bias', 5)]" + ] + }, + "execution_count": 86, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def count_parameters(model):\n", + " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", + "\n", + "model = Beagle2()\n", + "model = DanQ(50, 5)\n", + "\n", + "lst = [(x[0], x[1].numel()) for x in model.named_parameters()]\n", + "#np.sum([x[1] for x in lst])\n", + "count_parameters(model)\n", + "lst" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [ + { + "ename": "AttributeError", + "evalue": "'module' object has no attribute 'isin'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mtr_chrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr2'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr9'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr11'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mte_chrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr1'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr8'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr21'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mtraining_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtr_chrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mval_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mval_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mte_chrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mall_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtraining_df\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_df\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mAttributeError\u001b[0m: 'module' object has no attribute 'isin'" + ] + } + ], + "source": [ + "tr_chrs = ['chr2', 'chr9', 'chr11']\n", + "te_chrs = ['chr1', 'chr8', 'chr21']\n", + "training_df = train_chr[np.isin(train_chr['chr'], tr_chrs)]\n", + "val_df = val_chr[np.isin(val_chr['chr'], te_chrs)]\n", + "all_df = pd.concat([training_df, val_df])\n", + "\n", + "#filter_msk = all_df['start'] >= 0\n", + "filter_msk = all_df['start']%1000 == 0\n", + "all_df = all_df[filter_msk]" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1.12.1\n" + ] + } + ], + "source": [ + "print(np.__version__)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/users/abalsubr/anaconda2/envs/scs3/lib/python3.6/site-packages/ipykernel_launcher.py:6: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy\n", + " \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1.659163236618042\n" + ] + } + ], + "source": [ + "itime = time.time()\n", + "pd_list = []\n", + "for ct in all_celltypes:\n", + " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", + " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", + " tc_chr['celltype'] = ct\n", + " pd_list.append(tc_chr)\n", + "metadata_df = pd.concat(pd_list)\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3.0391879081726074\n" + ] + } + ], + "source": [ + "itime = time.time()\n", + "y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", + "non_ambig_mask = (y_array != -1)\n", + "metadata_df['y'] = y_array\n", + "_metadata_df = metadata_df[non_ambig_mask]\n", + "_y_array = torch.LongTensor(y_array[non_ambig_mask])\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "12.390011310577393\n" + ] + } + ], + "source": [ + "itime = time.time()\n", + "chr_ints = _metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['chr'])] )).values\n", + "celltype_ints = _metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['celltype'])] )).values\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/users/abalsubr/anaconda2/envs/scs3/lib/python3.6/site-packages/ipykernel_launcher.py:12: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy\n", + " if sys.path[0] == '':\n" + ] + } + ], + "source": [ + "train_chr_mask = np.isin(_metadata_df['chr'], tr_chrs)\n", + "val_chr_mask = np.isin(_metadata_df['chr'], te_chrs)\n", + "train_celltype_mask = np.isin(_metadata_df['celltype'], train_celltypes)\n", + "val_celltype_mask = np.isin(_metadata_df['celltype'], val_celltype)\n", + "test_celltype_mask = np.isin(_metadata_df['celltype'], test_celltype)\n", + "\n", + "split_array = -1*np.ones(_metadata_df.shape[0]).astype(int)\n", + "split_array[np.logical_and(train_chr_mask, train_celltype_mask)] = _split_dict['train']\n", + "split_array[np.logical_and(val_chr_mask, test_celltype_mask)] = _split_dict['test']\n", + "split_array[np.logical_and(val_chr_mask, val_celltype_mask)] = _split_dict['val-ood']\n", + "split_array[np.logical_and(val_chr_mask, train_celltype_mask)] = _split_dict['val-id']\n", + "_metadata_df['split'] = split_array\n", + "_split_array = split_array" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# get_input (idx)" + ] + }, + { + "cell_type": "code", + "execution_count": 153, + "metadata": {}, + "outputs": [], + "source": [ + "idx = 3\n", + "this_metadata = _metadata_df.iloc[idx, :]\n", + "\n", + "itime = time.time()\n", + "flank_size = 400\n", + "interval_start = this_metadata['start'] - flank_size\n", + "interval_end = this_metadata['stop'] + flank_size\n", + "dnase_this = _dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end]\n", + "seq_this = _seq_bp[this_metadata['chr']][interval_start:interval_end]\n", + "data = np.column_stack([seq_this, dnase_this])\n", + "# print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 154, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "4600" + ] + }, + "execution_count": 154, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data.shape\n", + "interval_end\n", + "# itime = time.time()\n", + "# np.save(os.path.join(data_dir, 'stmp.npy'), sa)\n", + "# print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mitime\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m metadata_array = torch.stack(\n\u001b[0;32m----> 3\u001b[0;31m (torch.LongTensor(metadata_df['chr'].values), \n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLongTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetadata_df\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'celltype'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m self._y_array),\n", + "\u001b[0;31mTypeError\u001b[0m: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool." + ] + } + ], + "source": [ + "itime = time.time()\n", + "metadata_array = torch.stack(\n", + " (torch.LongTensor(chr_ints), \n", + " torch.LongTensor(celltype_ints), \n", + " _y_array),\n", + " dim=1)\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 156, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name '_metadata_array' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0m_metadata_array\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mNameError\u001b[0m: name '_metadata_array' is not defined" + ] + } + ], + "source": [ + "_metadata_array" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from examples.models.model_attributes import model_attributes" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'utils'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_attributes\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmodel_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mset_seed\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mLogger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mCSVBatchLogger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mParseKwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mload\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 22\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdataset_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0moptimizer_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/dr_benchmark/examples/train.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0msave\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'utils'" + ] + } + ], + "source": [ + "import os, csv\n", + "import time\n", + "import argparse\n", + "import IPython\n", + "import pandas as pd\n", + "import torch\n", + "import torch.nn as nn\n", + "import torchvision\n", + "import sys\n", + "from collections import defaultdict\n", + "\n", + "# TODO: Replace this once we make wilds into an installed package\n", + "sys.path.insert(1, os.path.join(sys.path[0], '..'))\n", + "\n", + "from wilds.common.data_loaders import get_train_loader, get_eval_loader\n", + "from wilds.common.grouper import CombinatorialGrouper\n", + "from wilds.common.utils import get_counts\n", + "\n", + "from examples.models.model_attributes import model_attributes\n", + "from examples.utils import set_seed, Logger, CSVBatchLogger, log_args, ParseKwargs, load\n", + "from examples.train import train\n", + "from examples.data import dataset_attributes\n", + "from examples.optimizer import optimizer_attributes\n", + "from examples.scheduler import scheduler_attributes\n", + "from examples.loss import losses\n", + "from examples.utils import log_group_data\n", + "from examples.algorithms.constructors import algorithm_constructors\n", + "\n", + "\n", + "def initialize_algorithm(args, datasets, train_grouper):\n", + " train_dataset = datasets['train']['dataset']\n", + " train_loader = datasets['train']['loader']\n", + "\n", + " # Configure the final layer of the networks used\n", + " # The code below are defaults. Edit this if you need special config for your model.\n", + " if (train_dataset.is_classification) and (train_dataset.y_size == 1):\n", + " # For single-task classification, we have one output per class\n", + " d_out = train_dataset.n_classes\n", + " elif (not train_dataset.is_classification):\n", + " # For regression, we have one output per target dimension\n", + " d_out = train_dataset.y_size\n", + " else:\n", + " # TODO: Handle dataset-specific multi-task stuff here, e.g., for OGB\n", + " pass\n", + "\n", + " # Sanity checking input args\n", + " if args.algorithm == 'groupDRO':\n", + " assert args.train_loader_kwargs['uniform_over_groups']\n", + " elif args.algorithm in ['deepCORAL', 'IRM']:\n", + " assert args.train_loader == 'group'\n", + " assert args.train_loader_kwargs['uniform_over_groups']\n", + " assert args.train_loader_kwargs['distinct_groups']\n", + "\n", + " # Other config\n", + " n_train_steps = len(train_loader) * args.n_epochs\n", + " prediction_fn = dataset_attributes[args.dataset]['prediction_fn']\n", + " loss = losses[args.loss_function]\n", + " metric_constructor = dataset_attributes[args.dataset]['metric']\n", + " train_g = train_grouper.metadata_to_group(train_dataset.metadata_array)\n", + " is_group_in_train = get_counts(train_g, train_grouper.n_groups) > 0\n", + " algorithm_constructor = algorithm_constructors[args.algorithm]\n", + " algorithm = algorithm_constructor(\n", + " args=args,\n", + " d_out=d_out,\n", + " grouper=train_grouper,\n", + " prediction_fn=prediction_fn,\n", + " loss=loss,\n", + " metric_constructor=metric_constructor,\n", + " n_train_steps=n_train_steps,\n", + " is_group_in_train=is_group_in_train)\n", + " return algorithm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "parser = argparse.ArgumentParser()\n", + "\n", + "# Dataset\n", + "parser.add_argument('-d', '--dataset', choices=dataset_attributes.keys(), required=True)\n", + "parser.add_argument('--split_scheme', default='standard',\n", + " help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')\n", + "parser.add_argument('--root_dir', default=None, required=True,\n", + " help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')\n", + "parser.add_argument('--download', default=False, action='store_true',\n", + " help='If true, tries to downloads the dataset if it does not exist in root_dir.')\n", + "parser.add_argument('--frac', type=float, default=1.0,\n", + " help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.')\n", + "\n", + "# Loaders\n", + "parser.add_argument('--train_loader', choices=['standard', 'group'], default='standard')\n", + "parser.add_argument('--train_loader_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--eval_loader', choices=['standard', 'group'], default='standard')\n", + "parser.add_argument('--batch_size', type=int, default=32)\n", + "\n", + "# Model\n", + "parser.add_argument(\n", + " '--model',\n", + " choices=model_attributes.keys(),\n", + " default='resnet50')\n", + "parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},\n", + " help='keyword arguments for model initialization passed as key1=value1 key2=value2')\n", + "parser.add_argument('--train_from_scratch', action='store_true', default=False)\n", + "\n", + "# Algorithm and objective\n", + "parser.add_argument('--algorithm', required=True, choices=algorithm_constructors.keys())\n", + "parser.add_argument('--algorithm_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--groupby_fields', nargs='+', default=None)\n", + "parser.add_argument('--loss_function', required=True, choices = losses.keys()) #TODO: make default\n", + "parser.add_argument('--val_metric', default=None)\n", + "\n", + "# Optimization\n", + "parser.add_argument('--n_epochs', type=int, default=4)\n", + "parser.add_argument('--optimizer', default=None, choices=optimizer_attributes.keys())\n", + "parser.add_argument('--lr', type=float, required=True)\n", + "parser.add_argument('--weight_decay', type=float, required=True)\n", + "parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--scheduler', default=None, choices=scheduler_attributes.keys())\n", + "parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')\n", + "parser.add_argument('--scheduler_metric_name')\n", + "\n", + "# Evaluation\n", + "parser.add_argument('--evaluate_all_splits', action='store_true', default=False)\n", + "parser.add_argument('--additional_eval_splits', nargs='+', default=[])\n", + "\n", + "# Misc\n", + "parser.add_argument('--device', default='cuda')\n", + "parser.add_argument('--seed', type=int, default=0)\n", + "parser.add_argument('--log_dir', default='./logs')\n", + "parser.add_argument('--log_every', default=50, type=int)\n", + "parser.add_argument('--save_step', type=int, default=None)\n", + "parser.add_argument('--save_best', action='store_true', default=False)\n", + "parser.add_argument('--save_last', action='store_true', default=False)\n", + "parser.add_argument('--save_outputs', action='store_true', default=False)\n", + "parser.add_argument('--no_group_logging', action='store_true', default=False)\n", + "\n", + "parser.add_argument('--resume', default=False, action='store_true')\n", + "\n", + "args = parser.parse_args()\n", + "\n", + "# Set defaults\n", + "if args.groupby_fields is None:\n", + " args.no_group_logging = True\n", + "if args.val_metric is None:\n", + " args.val_metric = dataset_attributes[args.dataset]['val_metric']\n", + "\n", + "## Initialize logs\n", + "if os.path.exists(args.log_dir) and args.resume:\n", + " resume=True\n", + " mode='a'\n", + "else:\n", + " resume=False\n", + " mode='w'\n", + "if not os.path.exists(args.log_dir):\n", + " os.makedirs(args.log_dir)\n", + "logger = Logger(os.path.join(args.log_dir, 'log.txt'), mode)\n", + "\n", + "# Record args\n", + "log_args(args, logger)\n", + "\n", + "# Set random seed\n", + "set_seed(args.seed)\n", + "\n", + "# Data\n", + "full_dataset = dataset_attributes[args.dataset]['constructor'](\n", + " root_dir=args.root_dir,\n", + " download=args.download,\n", + " split_scheme=args.split_scheme)\n", + "\n", + "# To implement data augmentation (i.e., have different transforms\n", + "# at training time vs. test time), modify these two lines:\n", + "train_transform = dataset_attributes[args.dataset]['transform'](args.model)\n", + "eval_transform = dataset_attributes[args.dataset]['transform'](args.model)\n", + "\n", + "train_grouper = CombinatorialGrouper(\n", + " dataset=full_dataset,\n", + " groupby_fields=args.groupby_fields)\n", + "\n", + "datasets = defaultdict(dict)\n", + "for split in full_dataset.split_dict.keys():\n", + " if split=='train':\n", + " transform = train_transform\n", + " verbose = True\n", + " elif split == 'val':\n", + " transform = eval_transform\n", + " verbose = True\n", + " else:\n", + " transform = eval_transform\n", + " verbose = False\n", + " # Get subset\n", + " datasets[split]['dataset'] = full_dataset.get_subset(\n", + " split,\n", + " frac=args.frac,\n", + " transform=transform)\n", + "\n", + " # Get loader\n", + " shared_loader_kwargs = {\n", + " 'num_workers': 4,\n", + " 'pin_memory': True,\n", + " 'batch_size': args.batch_size,\n", + " 'collate_fn': dataset_attributes[args.dataset]['collate']\n", + " }\n", + "\n", + " if split == 'train':\n", + " datasets[split]['loader'] = get_train_loader(\n", + " loader=args.train_loader,\n", + " dataset=datasets[split]['dataset'],\n", + " grouper=train_grouper,\n", + " train_loader_kwargs=args.train_loader_kwargs,\n", + " **shared_loader_kwargs)\n", + " else:\n", + " datasets[split]['loader'] = get_eval_loader(\n", + " loader=args.eval_loader,\n", + " dataset=datasets[split]['dataset'],\n", + " grouper=train_grouper,\n", + " **shared_loader_kwargs)\n", + "\n", + " # Set fields\n", + " datasets[split]['split'] = split\n", + " datasets[split]['name'] = full_dataset.split_names[split]\n", + " datasets[split]['verbose'] = verbose\n", + " # Loggers\n", + " datasets[split]['eval_logger'] = CSVBatchLogger(\n", + " os.path.join(args.log_dir, f'{split}_eval.csv'), mode=mode)\n", + " datasets[split]['algo_logger'] = CSVBatchLogger(\n", + " os.path.join(args.log_dir, f'{split}_algo.csv'), mode=mode)\n", + "\n", + "# Logging dataset info\n", + "if args.no_group_logging and full_dataset.is_classification and full_dataset.y_size==1:\n", + " log_grouper = CombinatorialGrouper(\n", + " dataset=full_dataset,\n", + " groupby_fields=['y'])\n", + "elif args.no_group_logging:\n", + " log_grouper = None\n", + "else:\n", + " log_grouper = train_grouper\n", + "log_group_data(args, datasets, log_grouper, logger)\n", + "\n", + "## Initialize algorithm\n", + "algorithm = initialize_algorithm(args, datasets, train_grouper)\n", + "\n", + "## Load saved results if resuming\n", + "if resume:\n", + " save_path = os.path.join(args.log_dir, 'last_model.pth')\n", + " prev_epoch, best_val_metric = load(algorithm, save_path)\n", + " epoch_offset = prev_epoch + 1\n", + "else:\n", + " epoch_offset=0\n", + " best_val_metric=None\n", + "\n", + "train(algorithm,\n", + " datasets,\n", + " logger,\n", + " args,\n", + " epoch_offset=epoch_offset,\n", + " best_val_metric=best_val_metric)\n", + "\n", + "logger.close()\n", + "for split in datasets:\n", + " datasets[split]['eval_logger'].close()\n", + " datasets[split]['algo_logger'].close()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "language": "python", + "name": "python2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/wilds/datasets/camelyon17_dataset.py b/wilds/datasets/camelyon17_dataset.py index 691eac5d..d170ddf2 100644 --- a/wilds/datasets/camelyon17_dataset.py +++ b/wilds/datasets/camelyon17_dataset.py @@ -139,141 +139,3 @@ def eval(self, y_pred, y_true, metadata): self._metric, self._eval_grouper, y_pred, y_true, metadata) - - -class EncodeTFBSDataset(WILDSDataset): - """ - EncodeTFBS dataset - Website: https://www.synapse.org/#!Synapse:syn6131484 - """ - - def __init__(self, root_dir, download, split_scheme): - self._dataset_name = 'encodeTFBS' - self._download_url = 'https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/' - self._data_dir = self.initialize_data_dir(root_dir, download) - self._y_size = 1 - self._n_classes = 2 - - self._tr_chrs = ['chr2', 'chr9', 'chr11'] - self._te_chrs = ['chr1', 'chr8', 'chr21'] - self._transcription_factor = 'MAX' - self._train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] - self._val_celltype = ['A549'] - self._test_celltype = ['GM12878'] - self._all_celltypes = self._train_celltypes + self._val_celltype + self._test_celltype - - self._metadata_fields = ['chr', 'celltype', 'y'] - self._metadata_map = {} - self._metadata_map['chr'] = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] - self._metadata_map['celltype'] = self._all_celltypes - - # Load sequence and DNase features - sequence_filename = os.path.join(self._data_dir, 'sequence.npz') - seq_arr = np.load(sequence_filename) - self._seq_bp = {} - for chrom in seq_arr: - self._seq_bp[chrom] = seq_arr[chrom] - - self._dnase_allcelltypes = {} - for ct in self._all_celltypes: - dnase_filename = os.path.join(self._data_dir, '{}_dnase.npz'.format(ct)) - dnase_npz_file = np.load(dnase_filename) - self._dnase_allcelltypes[ct] = {} - for chrom in seq_bp: - self._dnase_allcelltypes[ct][chrom] = dnase_npz_file[chrom] - - # Read in metadata dataframe from training+validation data - train_chr = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.train.labels.tsv.gz'.format(self._transcription_factor)), sep='\t') - val_chr = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.val.labels.tsv.gz'.format(self._transcription_factor)), sep='\t') - training_df = train_chr[np.isin(train_chr['chr'], self._tr_chrs)] - val_df = val_chr[np.isin(val_chr['chr'], self._te_chrs)] - all_df = pd.concat([training_df, val_df]) - - # Filter by start/stop coordinate if needed - filter_msk = all_df['start'] >= 0 - filter_msk = all_df['start']%1000 == 0 - all_df = all_df[filter_msk] - - pd_list = [] - for ct in self._train_celltypes: - tc_chr = all_df[['chr', 'start', 'stop', ct]] - tc_chr.columns = ['chr', 'start', 'stop', 'y'] - tc_chr['celltype'] = ct - pd_list.append(tc_chr) - metadata_df = pd.concat(pd_list) - - # Get the y values, and remove ambiguous labels by default. - y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values - non_ambig_mask = (y_array != -1) - metadata_df['y'] = y_array - self._metadata_df = metadata_df[non_ambig_mask] - self._y_array = torch.LongTensor(y_array[non_ambig_mask]) - - chr_ints = self._metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['chr'])] )).values - celltype_ints = self._metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['celltype'])] )).values - self._metadata_array = torch.stack( - (torch.LongTensor(chr_ints), - torch.LongTensor(celltype_ints), - self._y_array), - dim=1) - - # Get the splits - # TODO Extract splits as encoded in split_scheme. Hardcoded here for now. - self._split_scheme = split_scheme - self._split_dict = { - 'train': 0, - 'val-id': 1, - 'test': 2, - 'val-ood': 3 - } - self._split_names = { - 'train': 'Train', - 'val-id': 'Validation (ID)', - 'test': 'Test', - 'val-ood': 'Validation (OOD)', - } - train_chr_mask = np.isin(self._metadata_df['chr'], self._tr_chrs) - val_chr_mask = np.isin(self._metadata_df['chr'], self._te_chrs) - train_celltype_mask = np.isin(self._metadata_df['celltype'], self._train_celltypes) - val_celltype_mask = np.isin(self._metadata_df['celltype'], self._val_celltype) - test_celltype_mask = np.isin(self._metadata_df['celltype'], self._test_celltype) - - split_array = -1*np.ones(self._metadata_df.shape[0]).astype(int) - split_array[np.logical_and(train_chr_mask, train_celltype_mask)] = self._split_dict['train'] - split_array[np.logical_and(val_chr_mask, test_celltype_mask)] = self._split_dict['test'] - # Validate using test chr, either using a designated validation cell line ('val-ood') or a training cell line ('val-id') - split_array[np.logical_and(val_chr_mask, val_celltype_mask)] = self._split_dict['val-ood'] - split_array[np.logical_and(val_chr_mask, train_celltype_mask)] = self._split_dict['val-id'] - if self._split_scheme=='standard': - self._metadata_df['split'] = split_array - self._split_array = split_array - else: - raise ValueError(f'Split scheme {self._split_scheme} not recognized') - self._eval_grouper = CombinatorialGrouper( - dataset=self, - groupby_fields=['celltype']) - self._metric = Auprc() - - super().__init__(root_dir, download, split_scheme) - - def get_input(self, idx): - """ - Returns x for a given idx. - Computes this from: - (1) sequence features in self._seq_bp - (2) DNase features in self._dnase_allcelltypes - (3) Metadata for the index (location along the genome with 1kb window width) - """ - this_metadata = self._metadata_df.iloc[idx, :] - flank_size = 400 - interval_start = this_metadata['start'] - flank_size - interval_end = this_metadata['stop'] + flank_size - dnase_this = _dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end] - seq_this = _seq_bp[this_metadata['chr']][interval_start:interval_end] - return np.column_stack([seq_this, dnase_this]) - - def eval(self, y_pred, y_true, metadata): - return self.standard_group_eval( - self._metric, - self._eval_grouper, - y_pred, y_true, metadata) diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 6996cc15..062e468a 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -4,9 +4,7 @@ import numpy as np from wilds.datasets.wilds_dataset import WILDSDataset from wilds.common.grouper import CombinatorialGrouper -from wilds.common.metrics.eval_metric import Accuracy - -import IPython +from wilds.common.metrics.all_metrics import Accuracy class EncodeTFBSDataset(WILDSDataset): """ From e1c6e8bcdfa0e59fc9780c406d92ac8ed491fe10 Mon Sep 17 00:00:00 2001 From: aikanor Date: Mon, 8 Feb 2021 17:43:25 -0800 Subject: [PATCH 008/244] integration revert 4/ --- wilds/datasets/camelyon17_dataset.py | 282 +++++++++++++-------------- 1 file changed, 141 insertions(+), 141 deletions(-) diff --git a/wilds/datasets/camelyon17_dataset.py b/wilds/datasets/camelyon17_dataset.py index d170ddf2..0a76f615 100644 --- a/wilds/datasets/camelyon17_dataset.py +++ b/wilds/datasets/camelyon17_dataset.py @@ -1,141 +1,141 @@ -import os -import torch -import pandas as pd -from PIL import Image -import numpy as np -from wilds.datasets.wilds_dataset import WILDSDataset -from wilds.common.grouper import CombinatorialGrouper -from wilds.common.metrics.all_metrics import Accuracy - -class Camelyon17Dataset(WILDSDataset): - """ - The CAMELYON17-wilds histopathology dataset. - This is a modified version of the original CAMELYON17 dataset. - - Supported `split_scheme`: - 'official' or 'in-dist' - - Input (x): - 96x96 image patches extracted from histopathology slides. - - Label (y): - y is binary. It is 1 if the central 32x32 region contains any tumor tissue, and 0 otherwise. - - Metadata: - Each patch is annotated with the ID of the hospital it came from (integer from 0 to 4) - and the slide it came from (integer from 0 to 49). - - Website: - https://camelyon17.grand-challenge.org/ - - Original publication: - @article{bandi2018detection, - title={From detection of individual metastases to classification of lymph node status at the patient level: the camelyon17 challenge}, - author={Bandi, Peter and Geessink, Oscar and Manson, Quirine and Van Dijk, Marcory and Balkenhol, Maschenka and Hermsen, Meyke and Bejnordi, Babak Ehteshami and Lee, Byungjae and Paeng, Kyunghyun and Zhong, Aoxiao and others}, - journal={IEEE transactions on medical imaging}, - volume={38}, - number={2}, - pages={550--560}, - year={2018}, - publisher={IEEE} - } - - License: - This dataset is in the public domain and is distributed under CC0. - https://creativecommons.org/publicdomain/zero/1.0/ - """ - - def __init__(self, root_dir='data', download=False, split_scheme='official'): - self._dataset_name = 'camelyon17' - self._version = '1.0' - self._download_url = 'https://worksheets.codalab.org/rest/bundles/0xe45e15f39fb54e9d9e919556af67aabe/contents/blob/' - self._compressed_size = 10_658_709_504 - self._data_dir = self.initialize_data_dir(root_dir, download) - self._original_resolution = (96,96) - - # Read in metadata - self._metadata_df = pd.read_csv( - os.path.join(self._data_dir, 'metadata.csv'), - index_col=0, - dtype={'patient': 'str'}) - - # Get the y values - self._y_array = torch.LongTensor(self._metadata_df['tumor'].values) - self._y_size = 1 - self._n_classes = 2 - - # Get filenames - self._input_array = [ - f'patches/patient_{patient}_node_{node}/patch_patient_{patient}_node_{node}_x_{x}_y_{y}.png' - for patient, node, x, y in - self._metadata_df.loc[:, ['patient', 'node', 'x_coord', 'y_coord']].itertuples(index=False, name=None)] - - # Extract splits - # Note that the hospital numbering here is different from what's in the paper, - # where to avoid confusing readers we used a 1-indexed scheme and just labeled the test hospital as 5. - # Here, the numbers are 0-indexed. - test_center = 2 - val_center = 1 - - self._split_dict = { - 'train': 0, - 'id_val': 1, - 'test': 2, - 'val': 3 - } - self._split_names = { - 'train': 'Train', - 'id_val': 'Validation (ID)', - 'test': 'Test', - 'val': 'Validation (OOD)', - } - centers = self._metadata_df['center'].values.astype('long') - num_centers = int(np.max(centers)) + 1 - val_center_mask = (self._metadata_df['center'] == val_center) - test_center_mask = (self._metadata_df['center'] == test_center) - self._metadata_df.loc[val_center_mask, 'split'] = self.split_dict['val'] - self._metadata_df.loc[test_center_mask, 'split'] = self.split_dict['test'] - - self._split_scheme = split_scheme - if self._split_scheme == 'official': - pass - elif self._split_scheme == 'in-dist': - # For the in-distribution oracle, - # we move slide 23 (corresponding to patient 042, node 3 in the original dataset) - # from the test set to the training set - slide_mask = (self._metadata_df['slide'] == 23) - self._metadata_df.loc[slide_mask, 'split'] = self.split_dict['train'] - else: - raise ValueError(f'Split scheme {self._split_scheme} not recognized') - self._split_array = self._metadata_df['split'].values - - self._metadata_array = torch.stack( - (torch.LongTensor(centers), - torch.LongTensor(self._metadata_df['slide'].values), - self._y_array), - dim=1) - self._metadata_fields = ['hospital', 'slide', 'y'] - - self._eval_grouper = CombinatorialGrouper( - dataset=self, - groupby_fields=['slide']) - - self._metric = Accuracy() - - super().__init__(root_dir, download, split_scheme) - - def get_input(self, idx): - """ - Returns x for a given idx. - """ - img_filename = os.path.join( - self.data_dir, - self._input_array[idx]) - x = Image.open(img_filename).convert('RGB') - return x - - def eval(self, y_pred, y_true, metadata): - return self.standard_group_eval( - self._metric, - self._eval_grouper, - y_pred, y_true, metadata) +import os +import torch +import pandas as pd +from PIL import Image +import numpy as np +from wilds.datasets.wilds_dataset import WILDSDataset +from wilds.common.grouper import CombinatorialGrouper +from wilds.common.metrics.all_metrics import Accuracy + +class Camelyon17Dataset(WILDSDataset): + """ + The CAMELYON17-wilds histopathology dataset. + This is a modified version of the original CAMELYON17 dataset. + + Supported `split_scheme`: + 'official' or 'in-dist' + + Input (x): + 96x96 image patches extracted from histopathology slides. + + Label (y): + y is binary. It is 1 if the central 32x32 region contains any tumor tissue, and 0 otherwise. + + Metadata: + Each patch is annotated with the ID of the hospital it came from (integer from 0 to 4) + and the slide it came from (integer from 0 to 49). + + Website: + https://camelyon17.grand-challenge.org/ + + Original publication: + @article{bandi2018detection, + title={From detection of individual metastases to classification of lymph node status at the patient level: the camelyon17 challenge}, + author={Bandi, Peter and Geessink, Oscar and Manson, Quirine and Van Dijk, Marcory and Balkenhol, Maschenka and Hermsen, Meyke and Bejnordi, Babak Ehteshami and Lee, Byungjae and Paeng, Kyunghyun and Zhong, Aoxiao and others}, + journal={IEEE transactions on medical imaging}, + volume={38}, + number={2}, + pages={550--560}, + year={2018}, + publisher={IEEE} + } + + License: + This dataset is in the public domain and is distributed under CC0. + https://creativecommons.org/publicdomain/zero/1.0/ + """ + + def __init__(self, root_dir='data', download=False, split_scheme='official'): + self._dataset_name = 'camelyon17' + self._version = '1.0' + self._download_url = 'https://worksheets.codalab.org/rest/bundles/0xe45e15f39fb54e9d9e919556af67aabe/contents/blob/' + self._compressed_size = 10_658_709_504 + self._data_dir = self.initialize_data_dir(root_dir, download) + self._original_resolution = (96,96) + + # Read in metadata + self._metadata_df = pd.read_csv( + os.path.join(self._data_dir, 'metadata.csv'), + index_col=0, + dtype={'patient': 'str'}) + + # Get the y values + self._y_array = torch.LongTensor(self._metadata_df['tumor'].values) + self._y_size = 1 + self._n_classes = 2 + + # Get filenames + self._input_array = [ + f'patches/patient_{patient}_node_{node}/patch_patient_{patient}_node_{node}_x_{x}_y_{y}.png' + for patient, node, x, y in + self._metadata_df.loc[:, ['patient', 'node', 'x_coord', 'y_coord']].itertuples(index=False, name=None)] + + # Extract splits + # Note that the hospital numbering here is different from what's in the paper, + # where to avoid confusing readers we used a 1-indexed scheme and just labeled the test hospital as 5. + # Here, the numbers are 0-indexed. + test_center = 2 + val_center = 1 + + self._split_dict = { + 'train': 0, + 'id_val': 1, + 'test': 2, + 'val': 3 + } + self._split_names = { + 'train': 'Train', + 'id_val': 'Validation (ID)', + 'test': 'Test', + 'val': 'Validation (OOD)', + } + centers = self._metadata_df['center'].values.astype('long') + num_centers = int(np.max(centers)) + 1 + val_center_mask = (self._metadata_df['center'] == val_center) + test_center_mask = (self._metadata_df['center'] == test_center) + self._metadata_df.loc[val_center_mask, 'split'] = self.split_dict['val'] + self._metadata_df.loc[test_center_mask, 'split'] = self.split_dict['test'] + + self._split_scheme = split_scheme + if self._split_scheme == 'official': + pass + elif self._split_scheme == 'in-dist': + # For the in-distribution oracle, + # we move slide 23 (corresponding to patient 042, node 3 in the original dataset) + # from the test set to the training set + slide_mask = (self._metadata_df['slide'] == 23) + self._metadata_df.loc[slide_mask, 'split'] = self.split_dict['train'] + else: + raise ValueError(f'Split scheme {self._split_scheme} not recognized') + self._split_array = self._metadata_df['split'].values + + self._metadata_array = torch.stack( + (torch.LongTensor(centers), + torch.LongTensor(self._metadata_df['slide'].values), + self._y_array), + dim=1) + self._metadata_fields = ['hospital', 'slide', 'y'] + + self._eval_grouper = CombinatorialGrouper( + dataset=self, + groupby_fields=['slide']) + + self._metric = Accuracy() + + super().__init__(root_dir, download, split_scheme) + + def get_input(self, idx): + """ + Returns x for a given idx. + """ + img_filename = os.path.join( + self.data_dir, + self._input_array[idx]) + x = Image.open(img_filename).convert('RGB') + return x + + def eval(self, y_pred, y_true, metadata): + return self.standard_group_eval( + self._metric, + self._eval_grouper, + y_pred, y_true, metadata) From bc6bf5161a8b181a7c5db0b38af69fbcc799811e Mon Sep 17 00:00:00 2001 From: aikanor Date: Mon, 8 Feb 2021 17:44:44 -0800 Subject: [PATCH 009/244] integration revert 5/ --- wilds/version.py | 53 ++++++++++++++++++++++++------------------------ 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/wilds/version.py b/wilds/version.py index a35ec15c..3f7bf4a6 100644 --- a/wilds/version.py +++ b/wilds/version.py @@ -1,26 +1,27 @@ -# Adapted from https://github.com/snap-stanford/ogb/blob/master/ogb/version.py - -import os -import logging -from threading import Thread - -__version__ = '1.0.0' - -try: - os.environ['OUTDATED_IGNORE'] = '1' - from outdated import check_outdated # noqa -except ImportError: - check_outdated = None - -def check(): - try: - is_outdated, latest = check_outdated('wilds', __version__) - if is_outdated: - logging.warning( - f'The WILDS package is out of date. Your version is {__version__}, while the latest version is {latest}.') - except Exception: - pass - -if check_outdated is not None: - thread = Thread(target=check) - thread.start() +# Adapted from https://github.com/snap-stanford/ogb/blob/master/ogb/version.py + +import os +import logging +from threading import Thread + +__version__ = '1.0.0' + +try: + os.environ['OUTDATED_IGNORE'] = '1' + from outdated import check_outdated # noqa +except ImportError: + check_outdated = None + +def check(): + try: + is_outdated, latest = check_outdated('wilds', __version__) + if is_outdated: + logging.warning( + f'The WILDS package is out of date. Your version is ' + f'{__version__}, while the latest version is {latest}.') + except Exception: + pass + +if check_outdated is not None: + thread = Thread(target=check) + thread.start() From a984ad0136c54b953f24ba1738683f89d72a4d39 Mon Sep 17 00:00:00 2001 From: aikanor Date: Tue, 9 Feb 2021 00:23:04 -0800 Subject: [PATCH 010/244] integration 6/ --- .../sbox_run_expt.ipynb | 366 +++++++++++++++++- 1 file changed, 349 insertions(+), 17 deletions(-) rename sbox_run_expt.ipynb => examples/sbox_run_expt.ipynb (74%) diff --git a/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb similarity index 74% rename from sbox_run_expt.ipynb rename to examples/sbox_run_expt.ipynb index 612397ce..e74eeefb 100644 --- a/sbox_run_expt.ipynb +++ b/examples/sbox_run_expt.ipynb @@ -11,16 +11,7 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [ - { - "ename": "SyntaxError", - "evalue": "invalid syntax (version.py, line 20)", - "output_type": "error", - "traceback": [ - "\u001b[0;36m File \u001b[0;32m\"wilds/version.py\"\u001b[0;36m, line \u001b[0;32m20\u001b[0m\n\u001b[0;31m f'The WILDS package is out of date. Your version is {__version__}, while the latest version is {latest}.')\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax\n" - ] - } - ], + "outputs": [], "source": [ "import os, csv\n", "import time\n", @@ -43,12 +34,353 @@ "import configs.supported as supported" ] }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "''' set default hyperparams in default_hyperparams.py '''\n", + "parser = argparse.ArgumentParser()\n", + "\n", + "# Required arguments\n", + "parser.add_argument('-d', '--dataset', choices=supported.datasets, required=True)\n", + "parser.add_argument('--algorithm', required=True, choices=supported.algorithms)\n", + "parser.add_argument('--root_dir', required=True,\n", + " help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')\n", + "\n", + "# Dataset\n", + "parser.add_argument('--split_scheme', help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')\n", + "parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--download', default=False, type=parse_bool, const=True, nargs='?',\n", + " help='If true, tries to downloads the dataset if it does not exist in root_dir.')\n", + "parser.add_argument('--frac', type=float, default=1.0,\n", + " help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.')\n", + "\n", + "# Loaders\n", + "parser.add_argument('--loader_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--train_loader', choices=['standard', 'group'])\n", + "parser.add_argument('--uniform_over_groups', type=parse_bool, const=True, nargs='?')\n", + "parser.add_argument('--distinct_groups', type=parse_bool, const=True, nargs='?')\n", + "parser.add_argument('--n_groups_per_batch', type=int)\n", + "parser.add_argument('--batch_size', type=int)\n", + "parser.add_argument('--eval_loader', choices=['standard'], default='standard')\n", + "\n", + "# Model\n", + "parser.add_argument('--model', choices=supported.models)\n", + "parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},\n", + " help='keyword arguments for model initialization passed as key1=value1 key2=value2')\n", + "\n", + "# Transforms\n", + "parser.add_argument('--train_transform', choices=supported.transforms)\n", + "parser.add_argument('--eval_transform', choices=supported.transforms)\n", + "parser.add_argument('--target_resolution', nargs='+', type=int, help='target resolution. for example --target_resolution 224 224 for standard resnet.')\n", + "parser.add_argument('--resize_scale', type=float)\n", + "parser.add_argument('--max_token_length', type=int)\n", + "\n", + "# Objective\n", + "parser.add_argument('--loss_function', choices = supported.losses)\n", + "\n", + "# Algorithm\n", + "parser.add_argument('--groupby_fields', nargs='+')\n", + "parser.add_argument('--group_dro_step_size', type=float)\n", + "parser.add_argument('--coral_penalty_weight', type=float)\n", + "parser.add_argument('--irm_lambda', type=float)\n", + "parser.add_argument('--irm_penalty_anneal_iters', type=int)\n", + "parser.add_argument('--algo_log_metric')\n", + "\n", + "# Model selection\n", + "parser.add_argument('--val_metric')\n", + "parser.add_argument('--val_metric_decreasing', type=parse_bool, const=True, nargs='?')\n", + "\n", + "# Optimization\n", + "parser.add_argument('--n_epochs', type=int)\n", + "parser.add_argument('--optimizer', choices=supported.optimizers)\n", + "parser.add_argument('--lr', type=float)\n", + "parser.add_argument('--weight_decay', type=float)\n", + "parser.add_argument('--max_grad_norm', type=float)\n", + "parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "\n", + "# Scheduler\n", + "parser.add_argument('--scheduler', choices=supported.schedulers)\n", + "parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')\n", + "parser.add_argument('--scheduler_metric_name')\n", + "\n", + "# Evaluation\n", + "parser.add_argument('--evaluate_all_splits', type=parse_bool, const=True, nargs='?', default=True)\n", + "parser.add_argument('--eval_splits', nargs='+', default=[])\n", + "parser.add_argument('--eval_only', type=parse_bool, const=True, nargs='?', default=False)\n", + "parser.add_argument('--eval_epoch', default=None, type=int)\n", + "\n", + "# Misc\n", + "parser.add_argument('--device', type=int, default=0)\n", + "parser.add_argument('--seed', type=int, default=0)\n", + "parser.add_argument('--log_dir', default='./logs')\n", + "parser.add_argument('--log_every', default=50, type=int)\n", + "parser.add_argument('--save_step', type=int)\n", + "parser.add_argument('--save_best', type=parse_bool, const=True, nargs='?', default=True)\n", + "parser.add_argument('--save_last', type=parse_bool, const=True, nargs='?', default=True)\n", + "parser.add_argument('--no_group_logging', type=parse_bool, const=True, nargs='?')\n", + "parser.add_argument('--use_wandb', type=parse_bool, const=True, nargs='?', default=False)\n", + "parser.add_argument('--progress_bar', type=parse_bool, const=True, nargs='?', default=False)\n", + "parser.add_argument('--resume', type=parse_bool, const=True, nargs='?', default=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "argstr = \"--dataset camelyon17 --algorithm ERM --root_dir data\"" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "config = parser.parse_args(argstr.split())\n", + "config = populate_defaults(config)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'camelyon17'" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.cuda.is_available()\n", + "config.dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'amazon': wilds.datasets.amazon_dataset.AmazonDataset,\n", + " 'camelyon17': wilds.datasets.camelyon17_dataset.Camelyon17Dataset,\n", + " 'celebA': wilds.datasets.celebA_dataset.CelebADataset,\n", + " 'civilcomments': wilds.datasets.civilcomments_dataset.CivilCommentsDataset,\n", + " 'iwildcam': wilds.datasets.iwildcam_dataset.IWildCamDataset,\n", + " 'waterbirds': wilds.datasets.waterbirds_dataset.WaterbirdsDataset,\n", + " 'yelp': wilds.datasets.yelp_dataset.YelpDataset,\n", + " 'ogb-molpcba': wilds.datasets.ogbmolpcba_dataset.OGBPCBADataset,\n", + " 'poverty': wilds.datasets.poverty_dataset.PovertyMapDataset,\n", + " 'fmow': wilds.datasets.fmow_dataset.FMoWDataset,\n", + " 'bdd100k': wilds.datasets.bdd100k_dataset.BDD100KDataset,\n", + " 'encodeTFBS': wilds.datasets.encodetfbs_dataset.EncodeTFBSDataset}" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "supported.datasets#[config.dataset]" + ] + }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "# set device\n", + "config.device = torch.device(\"cuda:\" + str(config.device)) if torch.cuda.is_available() else torch.device(\"cpu\")\n", + "\n", + "## Initialize logs\n", + "if os.path.exists(config.log_dir) and config.resume:\n", + " resume=True\n", + " mode='a'\n", + "elif os.path.exists(config.log_dir) and config.eval_only:\n", + " resume=False\n", + " mode='a'\n", + "else:\n", + " resume=False\n", + " mode='w'\n", + "\n", + "if not os.path.exists(config.log_dir):\n", + " os.makedirs(config.log_dir)\n", + "logger = Logger(os.path.join(config.log_dir, 'log.txt'), mode)\n", + "\n", + "# Record config\n", + "log_config(config, logger)\n", + "\n", + "# Set random seed\n", + "set_seed(config.seed)\n", + "\n", + "# Data\n", + "full_dataset = supported.datasets[config.dataset](\n", + " root_dir=config.root_dir,\n", + " download=config.download,\n", + " split_scheme=config.split_scheme,\n", + " **config.dataset_kwargs)\n", + "\n", + "# To implement data augmentation (i.e., have different transforms\n", + "# at training time vs. test time), modify these two lines:\n", + "train_transform = initialize_transform(\n", + " transform_name=config.train_transform,\n", + " config=config,\n", + " dataset=full_dataset)\n", + "eval_transform = initialize_transform(\n", + " transform_name=config.eval_transform,\n", + " config=config,\n", + " dataset=full_dataset)\n", + "\n", + "train_grouper = CombinatorialGrouper(\n", + " dataset=full_dataset,\n", + " groupby_fields=config.groupby_fields)\n", + "\n", + "datasets = defaultdict(dict)\n", + "for split in full_dataset.split_dict.keys():\n", + " if split=='train':\n", + " transform = train_transform\n", + " verbose = True\n", + " elif split == 'val':\n", + " transform = eval_transform\n", + " verbose = True\n", + " else:\n", + " transform = eval_transform\n", + " verbose = False\n", + " # Get subset\n", + " datasets[split]['dataset'] = full_dataset.get_subset(\n", + " split,\n", + " frac=config.frac,\n", + " transform=transform)\n", + "\n", + " if split == 'train':\n", + " datasets[split]['loader'] = get_train_loader(\n", + " loader=config.train_loader,\n", + " dataset=datasets[split]['dataset'],\n", + " batch_size=config.batch_size,\n", + " uniform_over_groups=config.uniform_over_groups,\n", + " grouper=train_grouper,\n", + " distinct_groups=config.distinct_groups,\n", + " n_groups_per_batch=config.n_groups_per_batch,\n", + " **config.loader_kwargs)\n", + " else:\n", + " datasets[split]['loader'] = get_eval_loader(\n", + " loader=config.eval_loader,\n", + " dataset=datasets[split]['dataset'],\n", + " grouper=train_grouper,\n", + " batch_size=config.batch_size,\n", + " **config.loader_kwargs)\n", + "\n", + " # Set fields\n", + " datasets[split]['split'] = split\n", + " datasets[split]['name'] = full_dataset.split_names[split]\n", + " datasets[split]['verbose'] = verbose\n", + " # Loggers\n", + " # Loggers\n", + " datasets[split]['eval_logger'] = BatchLogger(\n", + " os.path.join(config.log_dir, f'{split}_eval.csv'), mode=mode, use_wandb=(config.use_wandb and verbose))\n", + " datasets[split]['algo_logger'] = BatchLogger(\n", + " os.path.join(config.log_dir, f'{split}_algo.csv'), mode=mode, use_wandb=(config.use_wandb and verbose))\n", + "\n", + " if config.use_wandb:\n", + " initialize_wandb(config)\n", + "\n", + "# Logging dataset info\n", + "if config.no_group_logging and full_dataset.is_classification and full_dataset.y_size==1:\n", + " log_grouper = CombinatorialGrouper(\n", + " dataset=full_dataset,\n", + " groupby_fields=['y'])\n", + "elif config.no_group_logging:\n", + " log_grouper = None\n", + "else:\n", + " log_grouper = train_grouper\n", + "log_group_data(datasets, log_grouper, logger)\n", + "\n", + "## Initialize algorithm\n", + "algorithm = initialize_algorithm(\n", + " config=config,\n", + " datasets=datasets,\n", + " train_grouper=train_grouper)\n", + "\n", + "if not config.eval_only:\n", + " ## Load saved results if resuming\n", + " resume_success = False\n", + " if resume:\n", + " save_path = os.path.join(config.log_dir, 'last_model.pth')\n", + " if not os.path.exists(save_path):\n", + " epochs = [\n", + " int(file.split('_')[0])\n", + " for file in os.listdir(config.log_dir) if file.endswith('.pth')]\n", + " if len(epochs) > 0:\n", + " latest_epoch = max(epochs)\n", + " save_path = os.path.join(config.log_dir, f'{latest_epoch}_model.pth')\n", + " try:\n", + " prev_epoch, best_val_metric = load(algorithm, save_path)\n", + " epoch_offset = prev_epoch + 1\n", + " logger.write(f'Resuming from epoch {epoch_offset} with best val metric {best_val_metric}')\n", + " resume_success = True\n", + " except FileNotFoundError:\n", + " pass\n", + "\n", + " if resume_success == False:\n", + " epoch_offset=0\n", + " best_val_metric=None\n", + "\n", + "\n", + " train(\n", + " algorithm=algorithm,\n", + " datasets=datasets,\n", + " general_logger=logger,\n", + " config=config,\n", + " epoch_offset=epoch_offset,\n", + " best_val_metric=best_val_metric)\n", + "else:\n", + " if config.eval_epoch is None:\n", + " eval_model_path = os.path.join(config.log_dir, 'best_model.pth')\n", + " else:\n", + " eval_model_path = os.path.join(config.log_dir, f'{config.eval_epoch}_model.pth')\n", + " best_epoch, best_val_metric = load(algorithm, eval_model_path)\n", + " if config.eval_epoch is None:\n", + " epoch = best_epoch\n", + " else:\n", + " epoch = config.eval_epoch\n", + " evaluate(\n", + " algorithm=algorithm,\n", + " datasets=datasets,\n", + " epoch=epoch,\n", + " general_logger=logger,\n", + " config=config)\n", + "\n", + "logger.close()\n", + "for split in datasets:\n", + " datasets[split]['eval_logger'].close()\n", + " datasets[split]['algo_logger'].close()" + ] }, { "cell_type": "markdown", @@ -1010,23 +1342,23 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 2", + "display_name": "Python 3", "language": "python", - "name": "python2" + "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", - "version": 2 + "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.13" + "pygments_lexer": "ipython3", + "version": "3.8.5" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } From 039241d0dc2a7d139b32b733fbc854e442a4f247 Mon Sep 17 00:00:00 2001 From: aikanor Date: Tue, 9 Feb 2021 09:34:18 -0800 Subject: [PATCH 011/244] integration 7/ --- examples/configs/datasets.py | 24 +- examples/configs/supported.py | 2 +- examples/sbox_run_expt.ipynb | 508 ++------------------------- sandbox_data.ipynb | 24 +- wilds/datasets/encodetfbs_dataset.py | 2 +- 5 files changed, 62 insertions(+), 498 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 1d15c7af..824b6f3d 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -50,8 +50,6 @@ 'weight_decay': 0.01, 'n_epochs': 5, 'n_groups_per_batch': 2, - 'irm_lambda': 1.0, - 'coral_penalty_weight': 0.1, 'algo_log_metric': 'accuracy', }, 'celebA': { @@ -89,6 +87,28 @@ 'algo_log_metric': 'accuracy', 'max_token_length': 300, }, + 'encode-tfbs': { + 'split_scheme': 'official', + 'model': 'beagle', + 'model_kwargs': {'pretrained': False}, + 'train_transform': None, + 'eval_transform': None, + 'loss_function': 'cross_entropy', + 'groupby_fields': ['hospital'], + 'val_metric': 'ap', + 'val_metric_decreasing': False, + 'optimizer': 'Adam', + # 'optimizer_kwargs': { }, + 'scheduler': None, + 'batch_size': 128, + 'lr': 0.001, + 'weight_decay': 0.01, + 'n_epochs': 1, + 'n_groups_per_batch': 2, + # 'irm_lambda': 1.0, + # 'coral_penalty_weight': 0.1, + # 'algo_log_metric': 'accuracy', + }, 'fmow': { 'split_scheme': 'official', 'dataset_kwargs': { diff --git a/examples/configs/supported.py b/examples/configs/supported.py index d39e096e..fd4d6a63 100644 --- a/examples/configs/supported.py +++ b/examples/configs/supported.py @@ -30,7 +30,7 @@ 'poverty': PovertyMapDataset, 'fmow': FMoWDataset, 'bdd100k': BDD100KDataset, - 'encodeTFBS': EncodeTFBSDataset, + 'encode-tfbs': EncodeTFBSDataset, } losses = { diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb index e74eeefb..9c19a1e9 100644 --- a/examples/sbox_run_expt.ipynb +++ b/examples/sbox_run_expt.ipynb @@ -42,7 +42,7 @@ { "data": { "text/plain": [ - "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" + "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" ] }, "execution_count": 2, @@ -140,73 +140,53 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ - "argstr = \"--dataset camelyon17 --algorithm ERM --root_dir data\"" + "argstr_camelyon = \"--dataset camelyon17 --algorithm ERM --root_dir data\"\n", + "argstr_encode = \"--dataset encode-tfbs --algorithm ERM --root_dir data\"" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ - "config = parser.parse_args(argstr.split())\n", - "config = populate_defaults(config)" + "config_camelyon = parser.parse_args(argstr_camelyon.split())\n", + "config_encode = parser.parse_args(argstr_encode.split())" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 10, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'camelyon17'" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "torch.cuda.is_available()\n", - "config.dataset" + "config_camelyon = populate_defaults(config_camelyon)\n", + "config_encode = populate_defaults(config_encode)" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'amazon': wilds.datasets.amazon_dataset.AmazonDataset,\n", - " 'camelyon17': wilds.datasets.camelyon17_dataset.Camelyon17Dataset,\n", - " 'celebA': wilds.datasets.celebA_dataset.CelebADataset,\n", - " 'civilcomments': wilds.datasets.civilcomments_dataset.CivilCommentsDataset,\n", - " 'iwildcam': wilds.datasets.iwildcam_dataset.IWildCamDataset,\n", - " 'waterbirds': wilds.datasets.waterbirds_dataset.WaterbirdsDataset,\n", - " 'yelp': wilds.datasets.yelp_dataset.YelpDataset,\n", - " 'ogb-molpcba': wilds.datasets.ogbmolpcba_dataset.OGBPCBADataset,\n", - " 'poverty': wilds.datasets.poverty_dataset.PovertyMapDataset,\n", - " 'fmow': wilds.datasets.fmow_dataset.FMoWDataset,\n", - " 'bdd100k': wilds.datasets.bdd100k_dataset.BDD100KDataset,\n", - " 'encodeTFBS': wilds.datasets.encodetfbs_dataset.EncodeTFBSDataset}" + "wilds.datasets.encodetfbs_dataset.EncodeTFBSDataset" ] }, - "execution_count": 18, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "supported.datasets#[config.dataset]" + "supported.datasets[config_encode.dataset]\n" ] }, { @@ -255,8 +235,15 @@ "eval_transform = initialize_transform(\n", " transform_name=config.eval_transform,\n", " config=config,\n", - " dataset=full_dataset)\n", - "\n", + " dataset=full_dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ "train_grouper = CombinatorialGrouper(\n", " dataset=full_dataset,\n", " groupby_fields=config.groupby_fields)\n", @@ -485,167 +472,6 @@ " print(ct, time.time() - itime)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class Beagle2(nn.Module):\n", - " \"\"\"\n", - " Neural net models over genomic sequence.\n", - " Input:\n", - " - sequence_length: int (default 1000) \n", - " - Shape: (N, 5, sequence_length, 1) with batch size N.\n", - " \n", - " Output:\n", - " - prediction (Tensor): float torch tensor of shape (N, )\n", - " \n", - " TODO: Finish docstring.\n", - " \"\"\"\n", - " def __init__(self):\n", - " \"\"\"\n", - " Parameters\n", - " ----------\n", - " sequence_length : int\n", - " n_genomic_features : int\n", - " \"\"\"\n", - " super(Beagle2, self).__init__()\n", - "\n", - " self.dropout = 0.3\n", - " self.num_cell_types = 1\n", - " self.conv1 = nn.Conv2d(5, 300, (19, 1), stride = (1, 1), padding=(9,0))\n", - " self.conv2 = nn.Conv2d(300, 200, (11, 1), stride = (1, 1), padding = (5,0))\n", - " self.conv3 = nn.Conv2d(200, 200, (7, 1), stride = (1, 1), padding = (4,0))\n", - " self.bn1 = nn.BatchNorm2d(300)\n", - " self.bn2 = nn.BatchNorm2d(200)\n", - " self.bn3 = nn.BatchNorm2d(200)\n", - " self.maxpool1 = nn.MaxPool2d((3, 1))\n", - " self.maxpool2 = nn.MaxPool2d((4, 1))\n", - " self.maxpool3 = nn.MaxPool2d((4, 1))\n", - "\n", - " self.fc1 = nn.Linear(4200, 1000)\n", - " self.bn4 = nn.BatchNorm1d(1000)\n", - "\n", - " self.fc2 = nn.Linear(1000, 1000)\n", - " self.bn5 = nn.BatchNorm1d(1000)\n", - "\n", - " self.fc3 = nn.Linear(1000, self.num_cell_types)\n", - "\n", - " def forward(self, s):\n", - " s = s.permute(0, 2, 1).contiguous() # batch_size x 4 x 1000\n", - " s = s.view(-1, 5, 1000, 1) # batch_size x 4 x 1000 x 1 [4 channels]\n", - " s = self.maxpool1(F.relu(self.bn1(self.conv1(s)))) # batch_size x 300 x 333 x 1\n", - " s = self.maxpool2(F.relu(self.bn2(self.conv2(s)))) # batch_size x 200 x 83 x 1\n", - " s = self.maxpool3(F.relu(self.bn3(self.conv3(s)))) # batch_size x 200 x 21 x 1\n", - " s = s.view(-1, 4200)\n", - " conv_out = s\n", - "\n", - " s = F.dropout(F.relu(self.bn4(self.fc1(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", - " #s = F.dropout(F.relu(self.bn5(self.fc2(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", - " \n", - " \n", - " s = self.fc3(s)\n", - "\n", - " return s, conv_out\n", - "\n", - "\n", - "class DanQ(nn.Module):\n", - " def __init__(self, sequence_length, n_genomic_features):\n", - " \"\"\"\n", - " Parameters\n", - " ----------\n", - " sequence_length : int\n", - " Input sequence length\n", - " n_genomic_features : int\n", - " Total number of features to predict\n", - " \"\"\"\n", - " super(DanQ, self).__init__()\n", - " self.nnet = nn.Sequential(\n", - " nn.Conv1d(4, 320, kernel_size=26),\n", - " nn.ReLU(inplace=True),\n", - " nn.MaxPool1d(\n", - " kernel_size=13, stride=13),\n", - " nn.Dropout(0.2))\n", - "\n", - " self.bdlstm = nn.Sequential(\n", - " nn.LSTM(\n", - " 320, 320, num_layers=1, batch_first=True, bidirectional=True))\n", - "\n", - " self._n_channels = math.floor(\n", - " (sequence_length - 25) / 13)\n", - " self.classifier = nn.Sequential(\n", - " nn.Dropout(0.5),\n", - " nn.Linear(self._n_channels * 640, 925),\n", - " nn.ReLU(inplace=True),\n", - " nn.Linear(925, n_genomic_features),\n", - " nn.Sigmoid())\n", - "\n", - " def forward(self, x):\n", - " \"\"\"Forward propagation of a batch.\n", - " \"\"\"\n", - " out = self.nnet(x)\n", - " reshape_out = out.transpose(0, 1).transpose(0, 2)\n", - " out, _ = self.bdlstm(reshape_out)\n", - " out = out.transpose(0, 1)\n", - " reshape_out = out.contiguous().view(\n", - " out.size(0), 640 * self._n_channels)\n", - " predict = self.classifier(reshape_out)\n", - " return predict\n", - "\n", - "\n", - "class DeepSEA(nn.Module):\n", - " def __init__(self, sequence_length, n_genomic_features):\n", - " \"\"\"\n", - " Parameters\n", - " ----------\n", - " sequence_length : int\n", - " n_genomic_features : int\n", - " \"\"\"\n", - " super(DeepSEA, self).__init__()\n", - " conv_kernel_size = 8\n", - " pool_kernel_size = 4\n", - "\n", - " self.conv_net = nn.Sequential(\n", - " nn.Conv1d(4, 320, kernel_size=conv_kernel_size),\n", - " nn.ReLU(inplace=True),\n", - " nn.MaxPool1d(\n", - " kernel_size=pool_kernel_size, stride=pool_kernel_size),\n", - " nn.Dropout(p=0.2),\n", - "\n", - " nn.Conv1d(320, 480, kernel_size=conv_kernel_size),\n", - " nn.ReLU(inplace=True),\n", - " nn.MaxPool1d(\n", - " kernel_size=pool_kernel_size, stride=pool_kernel_size),\n", - " nn.Dropout(p=0.2),\n", - "\n", - " nn.Conv1d(480, 960, kernel_size=conv_kernel_size),\n", - " nn.ReLU(inplace=True),\n", - " nn.Dropout(p=0.5))\n", - "\n", - " reduce_by = conv_kernel_size - 1\n", - " pool_kernel_size = float(pool_kernel_size)\n", - " self.n_channels = int(\n", - " np.floor(\n", - " (np.floor(\n", - " (sequence_length - reduce_by) / pool_kernel_size)\n", - " - reduce_by) / pool_kernel_size)\n", - " - reduce_by)\n", - " self.classifier = nn.Sequential(\n", - " nn.Linear(960 * self.n_channels, n_genomic_features),\n", - " nn.ReLU(inplace=True),\n", - " nn.Linear(n_genomic_features, n_genomic_features),\n", - " nn.Sigmoid())\n", - "\n", - " def forward(self, x):\n", - " \"\"\"Forward propagation of a batch.\n", - " \"\"\"\n", - " out = self.conv_net(x)\n", - " reshape_out = out.view(out.size(0), 960 * self.n_channels)\n", - " predict = self.classifier(reshape_out)\n", - " return predict" - ] - }, { "cell_type": "code", "execution_count": 78, @@ -1052,292 +878,6 @@ "source": [ "from examples.models.model_attributes import model_attributes" ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "ename": "ModuleNotFoundError", - "evalue": "No module named 'utils'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_attributes\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmodel_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mset_seed\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mLogger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mCSVBatchLogger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mParseKwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mload\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 22\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdataset_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0moptimizer_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/dr_benchmark/examples/train.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0msave\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'utils'" - ] - } - ], - "source": [ - "import os, csv\n", - "import time\n", - "import argparse\n", - "import IPython\n", - "import pandas as pd\n", - "import torch\n", - "import torch.nn as nn\n", - "import torchvision\n", - "import sys\n", - "from collections import defaultdict\n", - "\n", - "# TODO: Replace this once we make wilds into an installed package\n", - "sys.path.insert(1, os.path.join(sys.path[0], '..'))\n", - "\n", - "from wilds.common.data_loaders import get_train_loader, get_eval_loader\n", - "from wilds.common.grouper import CombinatorialGrouper\n", - "from wilds.common.utils import get_counts\n", - "\n", - "from examples.models.model_attributes import model_attributes\n", - "from examples.utils import set_seed, Logger, CSVBatchLogger, log_args, ParseKwargs, load\n", - "from examples.train import train\n", - "from examples.data import dataset_attributes\n", - "from examples.optimizer import optimizer_attributes\n", - "from examples.scheduler import scheduler_attributes\n", - "from examples.loss import losses\n", - "from examples.utils import log_group_data\n", - "from examples.algorithms.constructors import algorithm_constructors\n", - "\n", - "\n", - "def initialize_algorithm(args, datasets, train_grouper):\n", - " train_dataset = datasets['train']['dataset']\n", - " train_loader = datasets['train']['loader']\n", - "\n", - " # Configure the final layer of the networks used\n", - " # The code below are defaults. Edit this if you need special config for your model.\n", - " if (train_dataset.is_classification) and (train_dataset.y_size == 1):\n", - " # For single-task classification, we have one output per class\n", - " d_out = train_dataset.n_classes\n", - " elif (not train_dataset.is_classification):\n", - " # For regression, we have one output per target dimension\n", - " d_out = train_dataset.y_size\n", - " else:\n", - " # TODO: Handle dataset-specific multi-task stuff here, e.g., for OGB\n", - " pass\n", - "\n", - " # Sanity checking input args\n", - " if args.algorithm == 'groupDRO':\n", - " assert args.train_loader_kwargs['uniform_over_groups']\n", - " elif args.algorithm in ['deepCORAL', 'IRM']:\n", - " assert args.train_loader == 'group'\n", - " assert args.train_loader_kwargs['uniform_over_groups']\n", - " assert args.train_loader_kwargs['distinct_groups']\n", - "\n", - " # Other config\n", - " n_train_steps = len(train_loader) * args.n_epochs\n", - " prediction_fn = dataset_attributes[args.dataset]['prediction_fn']\n", - " loss = losses[args.loss_function]\n", - " metric_constructor = dataset_attributes[args.dataset]['metric']\n", - " train_g = train_grouper.metadata_to_group(train_dataset.metadata_array)\n", - " is_group_in_train = get_counts(train_g, train_grouper.n_groups) > 0\n", - " algorithm_constructor = algorithm_constructors[args.algorithm]\n", - " algorithm = algorithm_constructor(\n", - " args=args,\n", - " d_out=d_out,\n", - " grouper=train_grouper,\n", - " prediction_fn=prediction_fn,\n", - " loss=loss,\n", - " metric_constructor=metric_constructor,\n", - " n_train_steps=n_train_steps,\n", - " is_group_in_train=is_group_in_train)\n", - " return algorithm" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "parser = argparse.ArgumentParser()\n", - "\n", - "# Dataset\n", - "parser.add_argument('-d', '--dataset', choices=dataset_attributes.keys(), required=True)\n", - "parser.add_argument('--split_scheme', default='standard',\n", - " help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')\n", - "parser.add_argument('--root_dir', default=None, required=True,\n", - " help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')\n", - "parser.add_argument('--download', default=False, action='store_true',\n", - " help='If true, tries to downloads the dataset if it does not exist in root_dir.')\n", - "parser.add_argument('--frac', type=float, default=1.0,\n", - " help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.')\n", - "\n", - "# Loaders\n", - "parser.add_argument('--train_loader', choices=['standard', 'group'], default='standard')\n", - "parser.add_argument('--train_loader_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--eval_loader', choices=['standard', 'group'], default='standard')\n", - "parser.add_argument('--batch_size', type=int, default=32)\n", - "\n", - "# Model\n", - "parser.add_argument(\n", - " '--model',\n", - " choices=model_attributes.keys(),\n", - " default='resnet50')\n", - "parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},\n", - " help='keyword arguments for model initialization passed as key1=value1 key2=value2')\n", - "parser.add_argument('--train_from_scratch', action='store_true', default=False)\n", - "\n", - "# Algorithm and objective\n", - "parser.add_argument('--algorithm', required=True, choices=algorithm_constructors.keys())\n", - "parser.add_argument('--algorithm_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--groupby_fields', nargs='+', default=None)\n", - "parser.add_argument('--loss_function', required=True, choices = losses.keys()) #TODO: make default\n", - "parser.add_argument('--val_metric', default=None)\n", - "\n", - "# Optimization\n", - "parser.add_argument('--n_epochs', type=int, default=4)\n", - "parser.add_argument('--optimizer', default=None, choices=optimizer_attributes.keys())\n", - "parser.add_argument('--lr', type=float, required=True)\n", - "parser.add_argument('--weight_decay', type=float, required=True)\n", - "parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--scheduler', default=None, choices=scheduler_attributes.keys())\n", - "parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')\n", - "parser.add_argument('--scheduler_metric_name')\n", - "\n", - "# Evaluation\n", - "parser.add_argument('--evaluate_all_splits', action='store_true', default=False)\n", - "parser.add_argument('--additional_eval_splits', nargs='+', default=[])\n", - "\n", - "# Misc\n", - "parser.add_argument('--device', default='cuda')\n", - "parser.add_argument('--seed', type=int, default=0)\n", - "parser.add_argument('--log_dir', default='./logs')\n", - "parser.add_argument('--log_every', default=50, type=int)\n", - "parser.add_argument('--save_step', type=int, default=None)\n", - "parser.add_argument('--save_best', action='store_true', default=False)\n", - "parser.add_argument('--save_last', action='store_true', default=False)\n", - "parser.add_argument('--save_outputs', action='store_true', default=False)\n", - "parser.add_argument('--no_group_logging', action='store_true', default=False)\n", - "\n", - "parser.add_argument('--resume', default=False, action='store_true')\n", - "\n", - "args = parser.parse_args()\n", - "\n", - "# Set defaults\n", - "if args.groupby_fields is None:\n", - " args.no_group_logging = True\n", - "if args.val_metric is None:\n", - " args.val_metric = dataset_attributes[args.dataset]['val_metric']\n", - "\n", - "## Initialize logs\n", - "if os.path.exists(args.log_dir) and args.resume:\n", - " resume=True\n", - " mode='a'\n", - "else:\n", - " resume=False\n", - " mode='w'\n", - "if not os.path.exists(args.log_dir):\n", - " os.makedirs(args.log_dir)\n", - "logger = Logger(os.path.join(args.log_dir, 'log.txt'), mode)\n", - "\n", - "# Record args\n", - "log_args(args, logger)\n", - "\n", - "# Set random seed\n", - "set_seed(args.seed)\n", - "\n", - "# Data\n", - "full_dataset = dataset_attributes[args.dataset]['constructor'](\n", - " root_dir=args.root_dir,\n", - " download=args.download,\n", - " split_scheme=args.split_scheme)\n", - "\n", - "# To implement data augmentation (i.e., have different transforms\n", - "# at training time vs. test time), modify these two lines:\n", - "train_transform = dataset_attributes[args.dataset]['transform'](args.model)\n", - "eval_transform = dataset_attributes[args.dataset]['transform'](args.model)\n", - "\n", - "train_grouper = CombinatorialGrouper(\n", - " dataset=full_dataset,\n", - " groupby_fields=args.groupby_fields)\n", - "\n", - "datasets = defaultdict(dict)\n", - "for split in full_dataset.split_dict.keys():\n", - " if split=='train':\n", - " transform = train_transform\n", - " verbose = True\n", - " elif split == 'val':\n", - " transform = eval_transform\n", - " verbose = True\n", - " else:\n", - " transform = eval_transform\n", - " verbose = False\n", - " # Get subset\n", - " datasets[split]['dataset'] = full_dataset.get_subset(\n", - " split,\n", - " frac=args.frac,\n", - " transform=transform)\n", - "\n", - " # Get loader\n", - " shared_loader_kwargs = {\n", - " 'num_workers': 4,\n", - " 'pin_memory': True,\n", - " 'batch_size': args.batch_size,\n", - " 'collate_fn': dataset_attributes[args.dataset]['collate']\n", - " }\n", - "\n", - " if split == 'train':\n", - " datasets[split]['loader'] = get_train_loader(\n", - " loader=args.train_loader,\n", - " dataset=datasets[split]['dataset'],\n", - " grouper=train_grouper,\n", - " train_loader_kwargs=args.train_loader_kwargs,\n", - " **shared_loader_kwargs)\n", - " else:\n", - " datasets[split]['loader'] = get_eval_loader(\n", - " loader=args.eval_loader,\n", - " dataset=datasets[split]['dataset'],\n", - " grouper=train_grouper,\n", - " **shared_loader_kwargs)\n", - "\n", - " # Set fields\n", - " datasets[split]['split'] = split\n", - " datasets[split]['name'] = full_dataset.split_names[split]\n", - " datasets[split]['verbose'] = verbose\n", - " # Loggers\n", - " datasets[split]['eval_logger'] = CSVBatchLogger(\n", - " os.path.join(args.log_dir, f'{split}_eval.csv'), mode=mode)\n", - " datasets[split]['algo_logger'] = CSVBatchLogger(\n", - " os.path.join(args.log_dir, f'{split}_algo.csv'), mode=mode)\n", - "\n", - "# Logging dataset info\n", - "if args.no_group_logging and full_dataset.is_classification and full_dataset.y_size==1:\n", - " log_grouper = CombinatorialGrouper(\n", - " dataset=full_dataset,\n", - " groupby_fields=['y'])\n", - "elif args.no_group_logging:\n", - " log_grouper = None\n", - "else:\n", - " log_grouper = train_grouper\n", - "log_group_data(args, datasets, log_grouper, logger)\n", - "\n", - "## Initialize algorithm\n", - "algorithm = initialize_algorithm(args, datasets, train_grouper)\n", - "\n", - "## Load saved results if resuming\n", - "if resume:\n", - " save_path = os.path.join(args.log_dir, 'last_model.pth')\n", - " prev_epoch, best_val_metric = load(algorithm, save_path)\n", - " epoch_offset = prev_epoch + 1\n", - "else:\n", - " epoch_offset=0\n", - " best_val_metric=None\n", - "\n", - "train(algorithm,\n", - " datasets,\n", - " logger,\n", - " args,\n", - " epoch_offset=epoch_offset,\n", - " best_val_metric=best_val_metric)\n", - "\n", - "logger.close()\n", - "for split in datasets:\n", - " datasets[split]['eval_logger'].close()\n", - " datasets[split]['algo_logger'].close()" - ] } ], "metadata": { diff --git a/sandbox_data.ipynb b/sandbox_data.ipynb index 0a9806a6..4203968d 100644 --- a/sandbox_data.ipynb +++ b/sandbox_data.ipynb @@ -25,11 +25,15 @@ ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], - "source": [] + "source": [ + "# TODOs\n", + "\n", + "- change sequence length of model\n", + " - examples/configs/model.py\n", + " - examples/models/CNN_genome.py" + ] }, { "cell_type": "markdown", @@ -590,23 +594,23 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 2", + "display_name": "Python 3", "language": "python", - "name": "python2" + "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", - "version": 2 + "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.13" + "pygments_lexer": "ipython3", + "version": "3.8.5" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 062e468a..e55fba11 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -127,7 +127,7 @@ def get_input(self, idx): Computes this from: (1) sequence features in self._seq_bp (2) DNase features in self._dnase_allcelltypes - (3) Metadata for the index (location along the genome with 1kb window width) + (3) Metadata for the index (location along the genome with 200bp window width) """ this_metadata = self._metadata_df.iloc[idx, :] flank_size = 400 From f5f549a224e728dd5c1afe94b46dd51b7c24dcce Mon Sep 17 00:00:00 2001 From: aikanor Date: Tue, 9 Feb 2021 10:53:53 -0800 Subject: [PATCH 012/244] integration 8/ --- examples/sbox_run_expt.ipynb | 149 ++++++++++++++++++++++++--- wilds/datasets/encodetfbs_dataset.py | 17 ++- 2 files changed, 150 insertions(+), 16 deletions(-) diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb index 9c19a1e9..2fc71286 100644 --- a/examples/sbox_run_expt.ipynb +++ b/examples/sbox_run_expt.ipynb @@ -42,7 +42,7 @@ { "data": { "text/plain": [ - "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" + "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" ] }, "execution_count": 2, @@ -140,7 +140,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -150,50 +150,113 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "config_camelyon = parser.parse_args(argstr_camelyon.split())\n", - "config_encode = parser.parse_args(argstr_encode.split())" + "config_encode = parser.parse_args(argstr_encode.split())\n", + "config_camelyon = populate_defaults(config_camelyon)\n", + "config_encode = populate_defaults(config_encode)" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ - "torch.cuda.is_available()\n", - "config_camelyon = populate_defaults(config_camelyon)\n", - "config_encode = populate_defaults(config_encode)" + "config = config_camelyon" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "wilds.datasets.encodetfbs_dataset.EncodeTFBSDataset" + "'./logs'" ] }, - "execution_count": 12, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "supported.datasets[config_encode.dataset]\n" + "config.log_dir" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset: camelyon17\n", + "Algorithm: ERM\n", + "Root dir: data\n", + "Split scheme: official\n", + "Dataset kwargs: {}\n", + "Download: False\n", + "Frac: 1.0\n", + "Loader kwargs: {'num_workers': 4, 'pin_memory': True}\n", + "Train loader: standard\n", + "Uniform over groups: False\n", + "Distinct groups: None\n", + "N groups per batch: 2\n", + "Batch size: 32\n", + "Eval loader: standard\n", + "Model: densenet121\n", + "Model kwargs: {'pretrained': False}\n", + "Train transform: image_base\n", + "Eval transform: image_base\n", + "Target resolution: (224, 224)\n", + "Resize scale: None\n", + "Max token length: None\n", + "Loss function: cross_entropy\n", + "Groupby fields: ['hospital']\n", + "Group dro step size: None\n", + "Coral penalty weight: None\n", + "Irm lambda: None\n", + "Irm penalty anneal iters: None\n", + "Algo log metric: accuracy\n", + "Val metric: acc_avg\n", + "Val metric decreasing: False\n", + "N epochs: 5\n", + "Optimizer: SGD\n", + "Lr: 0.001\n", + "Weight decay: 0.01\n", + "Max grad norm: None\n", + "Optimizer kwargs: {'momentum': 0.9}\n", + "Scheduler: None\n", + "Scheduler kwargs: {}\n", + "Scheduler metric split: val\n", + "Scheduler metric name: None\n", + "Evaluate all splits: True\n", + "Eval splits: []\n", + "Eval only: False\n", + "Eval epoch: None\n", + "Device: cuda:0\n", + "Seed: 0\n", + "Log dir: ./logs\n", + "Log every: 50\n", + "Save step: None\n", + "Save best: True\n", + "Save last: True\n", + "No group logging: False\n", + "Use wandb: False\n", + "Progress bar: False\n", + "Resume: False\n", + "\n" + ] + } + ], "source": [ "# set device\n", "config.device = torch.device(\"cuda:\" + str(config.device)) if torch.cuda.is_available() else torch.device(\"cpu\")\n", @@ -238,6 +301,64 @@ " dataset=full_dataset)" ] }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "full_dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "image_base None\n" + ] + } + ], + "source": [ + "supported.datasets[config_encode.dataset]\n", + "print(config_camelyon.train_transform, config_encode.train_transform)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "full_dataset.y_size" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index e55fba11..d26d052a 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -8,8 +8,20 @@ class EncodeTFBSDataset(WILDSDataset): """ - EncodeTFBS dataset - Website: https://www.synapse.org/#!Synapse:syn6131484 + ENCODE-DREAM-wilds dataset of transcription factor binding sites. + This is a subset of the dataset from the ENCODE-DREAM in vivo Transcription Factor Binding Site Prediction Challenge. + + Input (x): + 1000-base-pair regions of sequence with a quantified chromatin accessibility readout. + + Label (y): + y is binary. It is 1 if the central 200bp region is bound by the transcription factor MAX, and 0 otherwise. + + Metadata: + Each sequence is annotated with the celltype of origin (a string) and the chromosome of origin (a string). + + Website: + https://www.synapse.org/#!Synapse:syn6131484 """ def __init__(self, root_dir, download, split_scheme): @@ -19,6 +31,7 @@ def __init__(self, root_dir, download, split_scheme): self._y_size = 1 self._n_classes = 2 + # self._tr_chrs = ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] self._tr_chrs = ['chr2', 'chr9', 'chr11'] self._te_chrs = ['chr1', 'chr8', 'chr21'] self._transcription_factor = 'MAX' From 04e21f95ec1f7feeafe5ef7ef7117e4f580daee3 Mon Sep 17 00:00:00 2001 From: aikanor Date: Tue, 9 Feb 2021 11:37:26 -0800 Subject: [PATCH 013/244] integration 9/ --- sandbox_data.ipynb | 4 +++- wilds/common/metrics/all_metrics.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/sandbox_data.ipynb b/sandbox_data.ipynb index 4203968d..681c5ec2 100644 --- a/sandbox_data.ipynb +++ b/sandbox_data.ipynb @@ -8,7 +8,7 @@ " - run_expt.py\n", " - configs\n", " - [x] supported.py\n", - " - [ ] model.py\n", + " - [x] model.py\n", " - [ ] datasets.py\n", " - models\n", " - [x] CNN_genome.py\n", @@ -30,6 +30,8 @@ "source": [ "# TODOs\n", "\n", + "- change evaluation metric\n", + "\n", "- change sequence length of model\n", " - examples/configs/model.py\n", " - examples/models/CNN_genome.py" diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py index 3c2af169..85506eab 100644 --- a/wilds/common/metrics/all_metrics.py +++ b/wilds/common/metrics/all_metrics.py @@ -80,6 +80,25 @@ def _compute(self, y_pred, y_true): def worst(self, metrics): return minimum(metrics) +class AveragePrecision(Metric): + def __init__(self, prediction_fn=logits_to_pred, name=None, average='weighted'): + self.prediction_fn = prediction_fn + if name is None: + name = f'avgprec' + if average is not None: + name+=f'-{average}' + self.average = average + super().__init__(name=name) + + def _compute(self, y_pred, y_true): + if self.prediction_fn is not None: + y_pred = self.prediction_fn(y_pred) + score = sklearn.metrics.average_precision_score(y_true, y_pred, average=self.average, labels=torch.unique(y_true)) + return torch.tensor(score) + + def worst(self, metrics): + return minimum(metrics) + class F1(Metric): def __init__(self, prediction_fn=logits_to_pred, name=None, average='binary'): self.prediction_fn = prediction_fn From b5fec57e4a33ff99a3ffac83ee2562fcc7d2d94f Mon Sep 17 00:00:00 2001 From: aikanor Date: Tue, 9 Feb 2021 11:50:07 -0800 Subject: [PATCH 014/244] using avg accuracy across balanced splits for now, until avg precision is tested --- examples/configs/datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 824b6f3d..8eecbedd 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -95,7 +95,7 @@ 'eval_transform': None, 'loss_function': 'cross_entropy', 'groupby_fields': ['hospital'], - 'val_metric': 'ap', + 'val_metric': 'acc_avg', 'val_metric_decreasing': False, 'optimizer': 'Adam', # 'optimizer_kwargs': { }, @@ -107,7 +107,7 @@ 'n_groups_per_batch': 2, # 'irm_lambda': 1.0, # 'coral_penalty_weight': 0.1, - # 'algo_log_metric': 'accuracy', + 'algo_log_metric': 'accuracy' }, 'fmow': { 'split_scheme': 'official', From 31d67f585c098964543e7eae3b14ef31dc637473 Mon Sep 17 00:00:00 2001 From: aikanor Date: Tue, 9 Feb 2021 11:56:06 -0800 Subject: [PATCH 015/244] fix --- examples/configs/datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 8eecbedd..7974159f 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -94,7 +94,7 @@ 'train_transform': None, 'eval_transform': None, 'loss_function': 'cross_entropy', - 'groupby_fields': ['hospital'], + 'groupby_fields': ['celltype'], 'val_metric': 'acc_avg', 'val_metric_decreasing': False, 'optimizer': 'Adam', @@ -105,9 +105,9 @@ 'weight_decay': 0.01, 'n_epochs': 1, 'n_groups_per_batch': 2, + 'algo_log_metric': 'accuracy', # 'irm_lambda': 1.0, # 'coral_penalty_weight': 0.1, - 'algo_log_metric': 'accuracy' }, 'fmow': { 'split_scheme': 'official', From a94feffbed064e8ffc0e7cab7eb8fa3303278b0d Mon Sep 17 00:00:00 2001 From: aikanor Date: Tue, 9 Feb 2021 15:56:50 -0800 Subject: [PATCH 016/244] integration 10/ --- .../encode-tfbs/prep_accessibility.py | 4 +- examples/configs/datasets.py | 24 +--- examples/sbox_run_expt.ipynb | 130 ++++++++---------- wilds/datasets/encodetfbs_dataset.py | 11 +- 4 files changed, 68 insertions(+), 101 deletions(-) diff --git a/dataset_preprocessing/encode-tfbs/prep_accessibility.py b/dataset_preprocessing/encode-tfbs/prep_accessibility.py index 7342f797..31bf872c 100644 --- a/dataset_preprocessing/encode-tfbs/prep_accessibility.py +++ b/dataset_preprocessing/encode-tfbs/prep_accessibility.py @@ -1,8 +1,6 @@ -import numpy, pandas +import numpy as np import pyBigWig -from tqdm import tqdm - # Human chromosome names chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 7974159f..1d15c7af 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -50,6 +50,8 @@ 'weight_decay': 0.01, 'n_epochs': 5, 'n_groups_per_batch': 2, + 'irm_lambda': 1.0, + 'coral_penalty_weight': 0.1, 'algo_log_metric': 'accuracy', }, 'celebA': { @@ -87,28 +89,6 @@ 'algo_log_metric': 'accuracy', 'max_token_length': 300, }, - 'encode-tfbs': { - 'split_scheme': 'official', - 'model': 'beagle', - 'model_kwargs': {'pretrained': False}, - 'train_transform': None, - 'eval_transform': None, - 'loss_function': 'cross_entropy', - 'groupby_fields': ['celltype'], - 'val_metric': 'acc_avg', - 'val_metric_decreasing': False, - 'optimizer': 'Adam', - # 'optimizer_kwargs': { }, - 'scheduler': None, - 'batch_size': 128, - 'lr': 0.001, - 'weight_decay': 0.01, - 'n_epochs': 1, - 'n_groups_per_batch': 2, - 'algo_log_metric': 'accuracy', - # 'irm_lambda': 1.0, - # 'coral_penalty_weight': 0.1, - }, 'fmow': { 'split_scheme': 'official', 'dataset_kwargs': { diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb index 2fc71286..8b79c235 100644 --- a/examples/sbox_run_expt.ipynb +++ b/examples/sbox_run_expt.ipynb @@ -42,7 +42,7 @@ { "data": { "text/plain": [ - "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" + "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" ] }, "execution_count": 2, @@ -145,7 +145,7 @@ "outputs": [], "source": [ "argstr_camelyon = \"--dataset camelyon17 --algorithm ERM --root_dir data\"\n", - "argstr_encode = \"--dataset encode-tfbs --algorithm ERM --root_dir data\"" + "argstr_encode = \"--dataset encode-tfbs --algorithm ERM --root_dir data --download\"" ] }, { @@ -166,94 +166,72 @@ "metadata": {}, "outputs": [], "source": [ - "config = config_camelyon" + "#config = config_camelyon\n", + "config = config_encode" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'./logs'" + "True" ] }, - "execution_count": 10, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "config.log_dir" + "config.download" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Dataset: camelyon17\n", - "Algorithm: ERM\n", - "Root dir: data\n", - "Split scheme: official\n", - "Dataset kwargs: {}\n", - "Download: False\n", - "Frac: 1.0\n", - "Loader kwargs: {'num_workers': 4, 'pin_memory': True}\n", - "Train loader: standard\n", - "Uniform over groups: False\n", - "Distinct groups: None\n", - "N groups per batch: 2\n", - "Batch size: 32\n", - "Eval loader: standard\n", - "Model: densenet121\n", - "Model kwargs: {'pretrained': False}\n", - "Train transform: image_base\n", - "Eval transform: image_base\n", - "Target resolution: (224, 224)\n", - "Resize scale: None\n", - "Max token length: None\n", - "Loss function: cross_entropy\n", - "Groupby fields: ['hospital']\n", - "Group dro step size: None\n", - "Coral penalty weight: None\n", - "Irm lambda: None\n", - "Irm penalty anneal iters: None\n", - "Algo log metric: accuracy\n", - "Val metric: acc_avg\n", - "Val metric decreasing: False\n", - "N epochs: 5\n", - "Optimizer: SGD\n", - "Lr: 0.001\n", - "Weight decay: 0.01\n", - "Max grad norm: None\n", - "Optimizer kwargs: {'momentum': 0.9}\n", - "Scheduler: None\n", - "Scheduler kwargs: {}\n", - "Scheduler metric split: val\n", - "Scheduler metric name: None\n", - "Evaluate all splits: True\n", - "Eval splits: []\n", - "Eval only: False\n", - "Eval epoch: None\n", - "Device: cuda:0\n", - "Seed: 0\n", - "Log dir: ./logs\n", - "Log every: 50\n", - "Save step: None\n", - "Save best: True\n", - "Save last: True\n", - "No group logging: False\n", - "Use wandb: False\n", - "Progress bar: False\n", - "Resume: False\n", - "\n" + "Downloading dataset to data/encode-tfbs_v1.0...\n", + "You can also download the dataset manually at https://wilds.stanford.edu/downloads.\n", + "Downloading https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/ to data/encode-tfbs_v1.0/archive.tar.gz\n", + "\n", + "data/encode-tfbs_v1.0/archive.tar.gz may be corrupted. Please try deleting it and rerunning this command.\n", + "\n", + "Exception: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Exception ignored in: \n", + "Traceback (most recent call last):\n", + " File \"/users/abalsubr/anaconda2/envs/wilds1/lib/python3.8/site-packages/tqdm/std.py\", line 1134, in __del__\n", + " self.close()\n", + " File \"/users/abalsubr/anaconda2/envs/wilds1/lib/python3.8/site-packages/tqdm/notebook.py\", line 283, in close\n", + " self.disp(bar_style='success')\n", + "AttributeError: 'tqdm' object has no attribute 'disp'\n" + ] + }, + { + "ename": "FileNotFoundError", + "evalue": "[Errno 2] No such file or directory: 'data/encode-tfbs_v1.0/sequence.npz'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;31m# Data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m full_dataset = supported.datasets[config.dataset](\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0mroot_dir\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mroot_dir\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mdownload\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdownload\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/wilds/datasets/encodetfbs_dataset.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, root_dir, download, split_scheme)\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0;31m# Load sequence and DNase features\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0msequence_filename\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_data_dir\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'sequence.npz'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 51\u001b[0;31m \u001b[0mseq_arr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msequence_filename\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 52\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_seq_bp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mchrom\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mseq_arr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/numpy/lib/npyio.py\u001b[0m in \u001b[0;36mload\u001b[0;34m(file, mmap_mode, allow_pickle, fix_imports, encoding)\u001b[0m\n\u001b[1;32m 415\u001b[0m \u001b[0mown_fid\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 416\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 417\u001b[0;31m \u001b[0mfid\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mstack\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menter_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos_fspath\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfile\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"rb\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 418\u001b[0m \u001b[0mown_fid\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 419\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'data/encode-tfbs_v1.0/sequence.npz'" ] } ], @@ -277,7 +255,7 @@ "logger = Logger(os.path.join(config.log_dir, 'log.txt'), mode)\n", "\n", "# Record config\n", - "log_config(config, logger)\n", + "# log_config(config, logger)\n", "\n", "# Set random seed\n", "set_seed(config.seed)\n", @@ -303,22 +281,32 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "import copy\n", + "full_dataset_camelyon17 = copy.deepcopy(full_dataset)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 7, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "full_dataset" + "full_dataset_camelyon17" ] }, { @@ -341,7 +329,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -350,7 +338,7 @@ "1" ] }, - "execution_count": 9, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index d26d052a..23f0b1d7 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -25,7 +25,8 @@ class EncodeTFBSDataset(WILDSDataset): """ def __init__(self, root_dir, download, split_scheme): - self._dataset_name = 'encodeTFBS' + self._dataset_name = 'encode-tfbs' + self._version = '1.0' self._download_url = 'https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/' self._data_dir = self.initialize_data_dir(root_dir, download) self._y_size = 1 @@ -55,10 +56,10 @@ def __init__(self, root_dir, download, split_scheme): self._dnase_allcelltypes = {} for ct in self._all_celltypes: dnase_filename = os.path.join(self._data_dir, '{}_dnase.npz'.format(ct)) - dnase_npz_file = np.load(dnase_filename) + dnase_npz_contents = np.load(dnase_filename) self._dnase_allcelltypes[ct] = {} - for chrom in seq_bp: - self._dnase_allcelltypes[ct][chrom] = dnase_npz_file[chrom] + for chrom in self._seq_bp: + self._dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom] # Read in metadata dataframe from training+validation data train_chr = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.train.labels.tsv.gz'.format(self._transcription_factor)), sep='\t') @@ -130,7 +131,7 @@ def __init__(self, root_dir, download, split_scheme): self._eval_grouper = CombinatorialGrouper( dataset=self, groupby_fields=['celltype']) - self._metric = Auprc() + self._metric = Accuracy() super().__init__(root_dir, download, split_scheme) From cb9fc39e0caa334ba7512eae9eb1b964b934ad90 Mon Sep 17 00:00:00 2001 From: aikanor Date: Tue, 9 Feb 2021 19:00:21 -0800 Subject: [PATCH 017/244] integration 11/ --- .../encode-tfbs/prep_sequence.py | 214 ++++++++---------- examples/configs/datasets.py | 22 ++ 2 files changed, 116 insertions(+), 120 deletions(-) diff --git a/dataset_preprocessing/encode-tfbs/prep_sequence.py b/dataset_preprocessing/encode-tfbs/prep_sequence.py index 7f396d9f..7d6ede23 100644 --- a/dataset_preprocessing/encode-tfbs/prep_sequence.py +++ b/dataset_preprocessing/encode-tfbs/prep_sequence.py @@ -1,130 +1,104 @@ import argparse, time -import numpy, pandas +import numpy as np from tqdm import tqdm # Human chromosome names chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] -def one_hot_encode(sequence, ignore='N', alphabet=None, dtype='int8', - verbose=False, **kwargs): - """Converts a string or list of characters into a one-hot encoding. - This function will take in either a string or a list and convert it into a - one-hot encoding. If the input is a string, each character is assumed to be - a different symbol, e.g. 'ACGT' is assumed to be a sequence of four - characters. If the input is a list, the elements can be any size. - Although this function will be used here primarily to convert nucleotide - sequences into one-hot encoding with an alphabet of size 4, in principle - this function can be used for any types of sequences. - Parameters - ---------- - sequence : str or list - The sequence to convert to a one-hot encoding. - ignore : str, optional - A character to indicate setting nothing to 1 for that row, keeping the - encoding entirely 0's for that row. In the context of genomics, this is - the N character. Default is 'N'. - alphabet : set or tuple or list, optional - A pre-defined alphabet. If None is passed in, the alphabet will be - determined from the sequence, but this may be time consuming for - large sequences. Default is None. - dtype : str or numpy.dtype, optional - The data type of the returned encoding. Default is int8. - verbose : bool or str, optional - Whether to display a progress bar. If a string is passed in, use as the - name of the progressbar. Default is False. - kwargs : arguments - Arguments to be passed into tqdm. Default is None. - Returns - ------- - ohe : numpy.ndarray - A binary matrix of shape (alphabet_size, sequence_length) where - alphabet_size is the number of unique elements in the sequence and - sequence_length is the length of the input sequence. - """ - - name = None if verbose in (True, False) else verbose - d = verbose is False - - if isinstance(sequence, str): - sequence = list(sequence) - - alphabet = alphabet or numpy.unique(sequence) - alphabet = [char for char in alphabet if char != ignore] - alphabet_lookup = {char: i for i, char in enumerate(alphabet)} - - ohe = numpy.zeros((len(sequence), len(alphabet)), dtype=dtype) - for i, char in tqdm(enumerate(sequence), disable=d, desc=name, **kwargs): - if char != ignore: - idx = alphabet_lookup[char] - ohe[i, idx] = 1 - - return ohe - - -def read_fasta(filename, include_chroms=None, exclude_chroms=None, - ignore='N', alphabet=['A', 'C', 'G', 'T', 'N'], verbose=True): - """Read in a FASTA file and output a dictionary of sequences. - This function will take in the path to a FASTA-formatted file and output - a string containing the sequence for each chromosome. Optionally, - the user can specify a set of chromosomes to include or exclude from - the returned dictionary. - Parameters - ---------- - filename : str - The path to the FASTA-formatted file to open. - include_chroms : set or tuple or list, optional - The exact names of chromosomes in the FASTA file to include, excluding - all others. If None, include all chromosomes (except those specified by - exclude_chroms). Default is None. - exclude_chroms : set or tuple or list, optional - The exact names of chromosomes in the FASTA file to exclude, including - all others. If None, include all chromosomes (or the set specified by - include_chroms). Default is None. - ignore : str, optional - A character to indicate setting nothing to 1 for that row, keeping the - encoding entirely 0's for that row. In the context of genomics, this is - the N character. Default is 'N'. - alphabet : set or tuple or list, optional - A pre-defined alphabet. If None is passed in, the alphabet will be - determined from the sequence, but this may be time consuming for - large sequences. Must include the ignore character. Default is - ['A', 'C', 'G', 'T', 'N']. - verbose : bool or str, optional - Whether to display a progress bar. If a string is passed in, use as the - name of the progressbar. Default is False. - Returns - ------- - chroms : dict - A dictionary of strings where the keys are the names of the - chromosomes (exact strings from the header lines in the FASTA file) - and the values are the strings encoded there. - """ - - sequences = {} - name, sequence = None, None - skip_chrom = False - - with open(filename, "r") as infile: - for line in tqdm(infile, disable=not verbose): - if line.startswith(">"): - if name is not None and skip_chrom is False: - sequences[name] = ''.join(sequence) - - sequence = [] - name = line[1:].strip("\n") - if include_chroms is not None and name not in include_chroms: - skip_chrom = True - elif exclude_chroms is not None and name in exclude_chroms: - skip_chrom = True - else: - skip_chrom = False - - else: - if skip_chrom == False: - sequence.append(line.rstrip("\n").upper()) - - return sequences +def one_hot_encode(sequence, ignore='N', alphabet=None, dtype='int8', verbose=False, **kwargs): + """ + Converts a string or list of characters into a one-hot encoding. + This function will take in either a string or a list and convert it into a one-hot encoding. If the input is a string, each character is assumed to be a different symbol, e.g. 'ACGT' is assumed to be a sequence of four characters. If the input is a list, the elements can be any size. + Although this function will be used here primarily to convert nucleotide sequences into one-hot encoding with an alphabet of size 4, in principle this function can be used for any types of sequences. + + Parameters + ---------- + sequence : str or list + The sequence to convert to a one-hot encoding. + ignore : str, optional + A character to indicate setting nothing to 1 for that row, keeping the encoding entirely 0's for that row. In the context of genomics, this is the N character. Default is 'N'. + alphabet : set or tuple or list, optional + A pre-defined alphabet. If None is passed in, the alphabet will be determined from the sequence, but this may be time consuming for large sequences. Default is None. + dtype : str or numpy.dtype, optional + The data type of the returned encoding. Default is int8. + verbose : bool or str, optional + Whether to display a progress bar. If a string is passed in, use as the name of the progressbar. Default is False. + kwargs : arguments + Arguments to be passed into tqdm. Default is None. + + Returns + ------- + ohe : numpy.ndarray + A binary matrix of shape (alphabet_size, sequence_length) where alphabet_size is the number of unique elements in the sequence and sequence_length is the length of the input sequence. + """ + + name = None if verbose in (True, False) else verbose + d = verbose is False + + if isinstance(sequence, str): + sequence = list(sequence) + + alphabet = alphabet or np.unique(sequence) + alphabet = [char for char in alphabet if char != ignore] + alphabet_lookup = {char: i for i, char in enumerate(alphabet)} + + ohe = np.zeros((len(sequence), len(alphabet)), dtype=dtype) + for i, char in tqdm(enumerate(sequence), disable=d, desc=name, **kwargs): + if char != ignore: + idx = alphabet_lookup[char] + ohe[i, idx] = 1 + + return ohe + + +def read_fasta(filename, include_chroms=None, exclude_chroms=None, ignore='N', alphabet=['A', 'C', 'G', 'T', 'N'], verbose=True): + """ + Read in a FASTA file and output a dictionary of sequences. + This function will take in the path to a FASTA-formatted file and output a string containing the sequence for each chromosome. Optionally, the user can specify a set of chromosomes to include or exclude from the returned dictionary. + + Parameters + ---------- + filename : str + The path to the FASTA-formatted file to open. + include_chroms : set or tuple or list, optional + The exact names of chromosomes in the FASTA file to include, excluding all others. If None, include all chromosomes (except those specified by exclude_chroms). Default is None. + exclude_chroms : set or tuple or list, optional + The exact names of chromosomes in the FASTA file to exclude, including all others. If None, include all chromosomes (or the set specified by include_chroms). Default is None. + ignore : str, optional + A character to indicate setting nothing to 1 for that row, keeping the encoding entirely 0's for that row. In the context of genomics, this is the N character. Default is 'N'. + alphabet : set or tuple or list, optional + A pre-defined alphabet. If None is passed in, the alphabet will be determined from the sequence, but this may be time consuming for large sequences. Must include the ignore character. Default is ['A', 'C', 'G', 'T', 'N']. + verbose : bool or str, optional + Whether to display a progress bar. If a string is passed in, use as the name of the progressbar. Default is False. + + Returns + ------- + chroms : dict + A dictionary of strings where the keys are the names of the chromosomes (exact strings from the header lines in the FASTA file) and the values are the strings encoded there. + """ + + sequences = {} + name, sequence = None, None + skip_chrom = False + + with open(filename, "r") as infile: + for line in tqdm(infile, disable=not verbose): + if line.startswith(">"): + if name is not None and skip_chrom is False: + sequences[name] = ''.join(sequence) + sequence = [] + name = line[1:].strip("\n") + if include_chroms is not None and name not in include_chroms: + skip_chrom = True + elif exclude_chroms is not None and name in exclude_chroms: + skip_chrom = True + else: + skip_chrom = False + else: + if skip_chrom == False: + sequence.append(line.rstrip("\n").upper()) + return sequences def generate_sequence_archive(seq_path='sequence/hg19.genome.fa', output_dir): diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 1d15c7af..58baf1c1 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -89,6 +89,28 @@ 'algo_log_metric': 'accuracy', 'max_token_length': 300, }, + 'encode-tfbs': { + 'split_scheme': 'official', + 'model': 'beagle', + 'model_kwargs': {'pretrained': False}, + 'train_transform': None, + 'eval_transform': None, + 'loss_function': 'cross_entropy', + 'groupby_fields': ['celltype'], + 'val_metric': 'acc_avg', + 'val_metric_decreasing': False, + 'optimizer': 'Adam', + # 'optimizer_kwargs': { }, + 'scheduler': None, + 'batch_size': 128, + 'lr': 0.001, + 'weight_decay': 0.01, + 'n_epochs': 1, + 'n_groups_per_batch': 2, + 'algo_log_metric': 'accuracy', + # 'irm_lambda': 1.0, + # 'coral_penalty_weight': 0.1, + }, 'fmow': { 'split_scheme': 'official', 'dataset_kwargs': { From 1a7675aab18938578f99974f235b5cf110e40d9e Mon Sep 17 00:00:00 2001 From: aikanor Date: Wed, 10 Feb 2021 00:18:56 -0800 Subject: [PATCH 018/244] integration 12/ --- dataset_preprocessing/encode-tfbs/README.md | 2 +- .../encode-tfbs/prep_accessibility.py | 12 +- .../encode-tfbs/prep_datasets.ipynb | 12 +- .../encode-tfbs/prep_sequence.py | 4 +- examples/models/CNN_genome.py | 15 +- examples/run_expt.py | 556 +++++++++--------- sandbox_data.ipynb | 9 +- wilds/common/metrics/all_metrics.py | 2 +- 8 files changed, 304 insertions(+), 308 deletions(-) diff --git a/dataset_preprocessing/encode-tfbs/README.md b/dataset_preprocessing/encode-tfbs/README.md index 0be5fbd6..616d4cb5 100644 --- a/dataset_preprocessing/encode-tfbs/README.md +++ b/dataset_preprocessing/encode-tfbs/README.md @@ -5,7 +5,7 @@ #### Instructions -1. Download the human genome sequence (hg19 assembly) in FASTA format from http://hgdownload.cse.ucsc.edu/goldenpath/hg19/bigZips/hg19.fa.gz into `SEQUENCE_PATH`. +1. Download the human genome sequence (hg19 assembly) in FASTA format from http://hgdownload.cse.ucsc.edu/goldenpath/hg19/bigZips/hg19.fa.gz and extract it into `SEQUENCE_PATH`. 2. Run `python prep_sequence.py --seq_path SEQUENCE_PATH --output_dir OUTPUT_DIR` to write the fasta file found in `SEQUENCE_PATH` to a numpy array archive in `OUTPUT_DIR`. diff --git a/dataset_preprocessing/encode-tfbs/prep_accessibility.py b/dataset_preprocessing/encode-tfbs/prep_accessibility.py index 31bf872c..141981c0 100644 --- a/dataset_preprocessing/encode-tfbs/prep_accessibility.py +++ b/dataset_preprocessing/encode-tfbs/prep_accessibility.py @@ -1,24 +1,24 @@ +import argparse, time import numpy as np import pyBigWig # Human chromosome names chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] -def generate_accessibility_archives(input_dir, output_dir): +def generate_accessibility_archives(input_dir='dnase_bigwigs', output_dir='codalab_archive'): dnases = {} celltypes = ['A549', 'GM12878', 'H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] - for ctype in celltypes:#glob.glob('dnase_bigwigs/*'): + for ctype in celltypes: itime = time.time() - # ctype = pth.split('/')[1].split('.')[1] bw = pyBigWig.open("{}/DNASE.{}.fc.signal.bigwig".format(input_dir, ctype)) chromsizes = bw.chroms() - print(ctype, time.time() - itime) dn_dict = {} for chrom in chromsizes: #chr_IDs: x = bw.values(chrom, 0, chromsizes[chrom], numpy=True) - dn_dict[chrom] = np.nan_to_num(x).astype(np.float16) # half-precision makes things significantly smaller (less time to load) - print(chrom, time.time() - itime) + # half-precision makes things significantly smaller (less time to load) + dn_dict[chrom] = np.nan_to_num(x).astype(np.float16) + print("{}, {}. Time: {}".format(ctype, chrom, time.time() - itime)) dnases[ctype] = dn_dict for ctype in dnases: diff --git a/dataset_preprocessing/encode-tfbs/prep_datasets.ipynb b/dataset_preprocessing/encode-tfbs/prep_datasets.ipynb index 4b1fdc10..78235fd7 100644 --- a/dataset_preprocessing/encode-tfbs/prep_datasets.ipynb +++ b/dataset_preprocessing/encode-tfbs/prep_datasets.ipynb @@ -257,23 +257,23 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 2", + "display_name": "Python 3", "language": "python", - "name": "python2" + "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", - "version": 2 + "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.13" + "pygments_lexer": "ipython3", + "version": "3.8.5" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/dataset_preprocessing/encode-tfbs/prep_sequence.py b/dataset_preprocessing/encode-tfbs/prep_sequence.py index 7d6ede23..3ead9a27 100644 --- a/dataset_preprocessing/encode-tfbs/prep_sequence.py +++ b/dataset_preprocessing/encode-tfbs/prep_sequence.py @@ -3,6 +3,8 @@ from tqdm import tqdm +# Sequence preprocessing. Code adapted from Jacob Schreiber. + # Human chromosome names chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] @@ -32,7 +34,7 @@ def one_hot_encode(sequence, ignore='N', alphabet=None, dtype='int8', verbose=Fa ohe : numpy.ndarray A binary matrix of shape (alphabet_size, sequence_length) where alphabet_size is the number of unique elements in the sequence and sequence_length is the length of the input sequence. """ - + name = None if verbose in (True, False) else verbose d = verbose is False diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py index 8a658eab..b0743960 100644 --- a/examples/models/CNN_genome.py +++ b/examples/models/CNN_genome.py @@ -6,23 +6,14 @@ class Beagle(nn.Module): """ - Neural net models over genomic sequence. + Neural net models over genomic sequence. Adapted from https://github.com/kundajelab/ChromDragoNN Input: - - sequence_length: int (default 1000) - - Shape: (N, 5, sequence_length, 1) with batch size N. + - s (Tensor): float torch tensor of shape (N, 5, 1000, 1) with batch size N. Output: - prediction (Tensor): float torch tensor of shape (N, ) - - TODO: Finish docstring. """ def __init__(self): - """ - Parameters - ---------- - sequence_length : int - n_genomic_features : int - """ super(Beagle, self).__init__() self.dropout = 0.3 @@ -57,6 +48,6 @@ def forward(self, s): s = F.dropout(F.relu(self.bn4(self.fc1(s))), p=self.dropout, training=self.training) # batch_size x 1000 s = F.dropout(F.relu(self.bn5(self.fc2(s))), p=self.dropout, training=self.training) # batch_size x 1000 - s = self.fc3(s) + prediction = self.fc3(s) return s#, conv_out diff --git a/examples/run_expt.py b/examples/run_expt.py index 710157c3..166df04f 100644 --- a/examples/run_expt.py +++ b/examples/run_expt.py @@ -1,278 +1,278 @@ -import os, csv -import time -import argparse -import pandas as pd -import torch -import torch.nn as nn -import torchvision -import sys -from collections import defaultdict - -from wilds.common.data_loaders import get_train_loader, get_eval_loader -from wilds.common.grouper import CombinatorialGrouper - -from utils import set_seed, Logger, BatchLogger, log_config, ParseKwargs, load, initialize_wandb, log_group_data, parse_bool -from train import train, evaluate -from algorithms.initializer import initialize_algorithm -from transforms import initialize_transform -from configs.utils import populate_defaults -import configs.supported as supported - -def main(): - ''' set default hyperparams in default_hyperparams.py ''' - parser = argparse.ArgumentParser() - - # Required arguments - parser.add_argument('-d', '--dataset', choices=supported.datasets, required=True) - parser.add_argument('--algorithm', required=True, choices=supported.algorithms) - parser.add_argument('--root_dir', required=True, - help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).') - - # Dataset - parser.add_argument('--split_scheme', help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.') - parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={}) - parser.add_argument('--download', default=False, type=parse_bool, const=True, nargs='?', - help='If true, tries to downloads the dataset if it does not exist in root_dir.') - parser.add_argument('--frac', type=float, default=1.0, - help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.') - - # Loaders - parser.add_argument('--loader_kwargs', nargs='*', action=ParseKwargs, default={}) - parser.add_argument('--train_loader', choices=['standard', 'group']) - parser.add_argument('--uniform_over_groups', type=parse_bool, const=True, nargs='?') - parser.add_argument('--distinct_groups', type=parse_bool, const=True, nargs='?') - parser.add_argument('--n_groups_per_batch', type=int) - parser.add_argument('--batch_size', type=int) - parser.add_argument('--eval_loader', choices=['standard'], default='standard') - - # Model - parser.add_argument('--model', choices=supported.models) - parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={}, - help='keyword arguments for model initialization passed as key1=value1 key2=value2') - - # Transforms - parser.add_argument('--train_transform', choices=supported.transforms) - parser.add_argument('--eval_transform', choices=supported.transforms) - parser.add_argument('--target_resolution', nargs='+', type=int, help='target resolution. for example --target_resolution 224 224 for standard resnet.') - parser.add_argument('--resize_scale', type=float) - parser.add_argument('--max_token_length', type=int) - - # Objective - parser.add_argument('--loss_function', choices = supported.losses) - - # Algorithm - parser.add_argument('--groupby_fields', nargs='+') - parser.add_argument('--group_dro_step_size', type=float) - parser.add_argument('--coral_penalty_weight', type=float) - parser.add_argument('--irm_lambda', type=float) - parser.add_argument('--irm_penalty_anneal_iters', type=int) - parser.add_argument('--algo_log_metric') - - # Model selection - parser.add_argument('--val_metric') - parser.add_argument('--val_metric_decreasing', type=parse_bool, const=True, nargs='?') - - # Optimization - parser.add_argument('--n_epochs', type=int) - parser.add_argument('--optimizer', choices=supported.optimizers) - parser.add_argument('--lr', type=float) - parser.add_argument('--weight_decay', type=float) - parser.add_argument('--max_grad_norm', type=float) - parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={}) - - # Scheduler - parser.add_argument('--scheduler', choices=supported.schedulers) - parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={}) - parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val') - parser.add_argument('--scheduler_metric_name') - - # Evaluation - parser.add_argument('--evaluate_all_splits', type=parse_bool, const=True, nargs='?', default=True) - parser.add_argument('--eval_splits', nargs='+', default=[]) - parser.add_argument('--eval_only', type=parse_bool, const=True, nargs='?', default=False) - parser.add_argument('--eval_epoch', default=None, type=int) - - # Misc - parser.add_argument('--device', type=int, default=0) - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--log_dir', default='./logs') - parser.add_argument('--log_every', default=50, type=int) - parser.add_argument('--save_step', type=int) - parser.add_argument('--save_best', type=parse_bool, const=True, nargs='?', default=True) - parser.add_argument('--save_last', type=parse_bool, const=True, nargs='?', default=True) - parser.add_argument('--no_group_logging', type=parse_bool, const=True, nargs='?') - parser.add_argument('--use_wandb', type=parse_bool, const=True, nargs='?', default=False) - parser.add_argument('--progress_bar', type=parse_bool, const=True, nargs='?', default=False) - parser.add_argument('--resume', type=parse_bool, const=True, nargs='?', default=False) - - config = parser.parse_args() - config = populate_defaults(config) - - # set device - config.device = torch.device("cuda:" + str(config.device)) if torch.cuda.is_available() else torch.device("cpu") - - ## Initialize logs - if os.path.exists(config.log_dir) and config.resume: - resume=True - mode='a' - elif os.path.exists(config.log_dir) and config.eval_only: - resume=False - mode='a' - else: - resume=False - mode='w' - - if not os.path.exists(config.log_dir): - os.makedirs(config.log_dir) - logger = Logger(os.path.join(config.log_dir, 'log.txt'), mode) - - # Record config - log_config(config, logger) - - # Set random seed - set_seed(config.seed) - - # Data - full_dataset = supported.datasets[config.dataset]( - root_dir=config.root_dir, - download=config.download, - split_scheme=config.split_scheme, - **config.dataset_kwargs) - - # To implement data augmentation (i.e., have different transforms - # at training time vs. test time), modify these two lines: - train_transform = initialize_transform( - transform_name=config.train_transform, - config=config, - dataset=full_dataset) - eval_transform = initialize_transform( - transform_name=config.eval_transform, - config=config, - dataset=full_dataset) - - train_grouper = CombinatorialGrouper( - dataset=full_dataset, - groupby_fields=config.groupby_fields) - - datasets = defaultdict(dict) - for split in full_dataset.split_dict.keys(): - if split=='train': - transform = train_transform - verbose = True - elif split == 'val': - transform = eval_transform - verbose = True - else: - transform = eval_transform - verbose = False - # Get subset - datasets[split]['dataset'] = full_dataset.get_subset( - split, - frac=config.frac, - transform=transform) - - if split == 'train': - datasets[split]['loader'] = get_train_loader( - loader=config.train_loader, - dataset=datasets[split]['dataset'], - batch_size=config.batch_size, - uniform_over_groups=config.uniform_over_groups, - grouper=train_grouper, - distinct_groups=config.distinct_groups, - n_groups_per_batch=config.n_groups_per_batch, - **config.loader_kwargs) - else: - datasets[split]['loader'] = get_eval_loader( - loader=config.eval_loader, - dataset=datasets[split]['dataset'], - grouper=train_grouper, - batch_size=config.batch_size, - **config.loader_kwargs) - - # Set fields - datasets[split]['split'] = split - datasets[split]['name'] = full_dataset.split_names[split] - datasets[split]['verbose'] = verbose - # Loggers - # Loggers - datasets[split]['eval_logger'] = BatchLogger( - os.path.join(config.log_dir, f'{split}_eval.csv'), mode=mode, use_wandb=(config.use_wandb and verbose)) - datasets[split]['algo_logger'] = BatchLogger( - os.path.join(config.log_dir, f'{split}_algo.csv'), mode=mode, use_wandb=(config.use_wandb and verbose)) - - if config.use_wandb: - initialize_wandb(config) - - # Logging dataset info - if config.no_group_logging and full_dataset.is_classification and full_dataset.y_size==1: - log_grouper = CombinatorialGrouper( - dataset=full_dataset, - groupby_fields=['y']) - elif config.no_group_logging: - log_grouper = None - else: - log_grouper = train_grouper - log_group_data(datasets, log_grouper, logger) - - ## Initialize algorithm - algorithm = initialize_algorithm( - config=config, - datasets=datasets, - train_grouper=train_grouper) - - if not config.eval_only: - ## Load saved results if resuming - resume_success = False - if resume: - save_path = os.path.join(config.log_dir, 'last_model.pth') - if not os.path.exists(save_path): - epochs = [ - int(file.split('_')[0]) - for file in os.listdir(config.log_dir) if file.endswith('.pth')] - if len(epochs) > 0: - latest_epoch = max(epochs) - save_path = os.path.join(config.log_dir, f'{latest_epoch}_model.pth') - try: - prev_epoch, best_val_metric = load(algorithm, save_path) - epoch_offset = prev_epoch + 1 - logger.write(f'Resuming from epoch {epoch_offset} with best val metric {best_val_metric}') - resume_success = True - except FileNotFoundError: - pass - - if resume_success == False: - epoch_offset=0 - best_val_metric=None - - - train( - algorithm=algorithm, - datasets=datasets, - general_logger=logger, - config=config, - epoch_offset=epoch_offset, - best_val_metric=best_val_metric) - else: - if config.eval_epoch is None: - eval_model_path = os.path.join(config.log_dir, 'best_model.pth') - else: - eval_model_path = os.path.join(config.log_dir, f'{config.eval_epoch}_model.pth') - best_epoch, best_val_metric = load(algorithm, eval_model_path) - if config.eval_epoch is None: - epoch = best_epoch - else: - epoch = config.eval_epoch - evaluate( - algorithm=algorithm, - datasets=datasets, - epoch=epoch, - general_logger=logger, - config=config) - - logger.close() - for split in datasets: - datasets[split]['eval_logger'].close() - datasets[split]['algo_logger'].close() - -if __name__=='__main__': - main() +import os, csv +import time +import argparse +import pandas as pd +import torch +import torch.nn as nn +import torchvision +import sys +from collections import defaultdict + +from wilds.common.data_loaders import get_train_loader, get_eval_loader +from wilds.common.grouper import CombinatorialGrouper + +from utils import set_seed, Logger, BatchLogger, log_config, ParseKwargs, load, initialize_wandb, log_group_data, parse_bool +from train import train, evaluate +from algorithms.initializer import initialize_algorithm +from transforms import initialize_transform +from configs.utils import populate_defaults +import configs.supported as supported + +def main(): + ''' set default hyperparams in default_hyperparams.py ''' + parser = argparse.ArgumentParser() + + # Required arguments + parser.add_argument('-d', '--dataset', choices=supported.datasets, required=True) + parser.add_argument('--algorithm', required=True, choices=supported.algorithms) + parser.add_argument('--root_dir', required=True, + help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).') + + # Dataset + parser.add_argument('--split_scheme', help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.') + parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={}) + parser.add_argument('--download', default=False, type=parse_bool, const=True, nargs='?', + help='If true, tries to downloads the dataset if it does not exist in root_dir.') + parser.add_argument('--frac', type=float, default=1.0, + help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.') + + # Loaders + parser.add_argument('--loader_kwargs', nargs='*', action=ParseKwargs, default={}) + parser.add_argument('--train_loader', choices=['standard', 'group']) + parser.add_argument('--uniform_over_groups', type=parse_bool, const=True, nargs='?') + parser.add_argument('--distinct_groups', type=parse_bool, const=True, nargs='?') + parser.add_argument('--n_groups_per_batch', type=int) + parser.add_argument('--batch_size', type=int) + parser.add_argument('--eval_loader', choices=['standard'], default='standard') + + # Model + parser.add_argument('--model', choices=supported.models) + parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={}, + help='keyword arguments for model initialization passed as key1=value1 key2=value2') + + # Transforms + parser.add_argument('--train_transform', choices=supported.transforms) + parser.add_argument('--eval_transform', choices=supported.transforms) + parser.add_argument('--target_resolution', nargs='+', type=int, help='target resolution. for example --target_resolution 224 224 for standard resnet.') + parser.add_argument('--resize_scale', type=float) + parser.add_argument('--max_token_length', type=int) + + # Objective + parser.add_argument('--loss_function', choices = supported.losses) + + # Algorithm + parser.add_argument('--groupby_fields', nargs='+') + parser.add_argument('--group_dro_step_size', type=float) + parser.add_argument('--coral_penalty_weight', type=float) + parser.add_argument('--irm_lambda', type=float) + parser.add_argument('--irm_penalty_anneal_iters', type=int) + parser.add_argument('--algo_log_metric') + + # Model selection + parser.add_argument('--val_metric') + parser.add_argument('--val_metric_decreasing', type=parse_bool, const=True, nargs='?') + + # Optimization + parser.add_argument('--n_epochs', type=int) + parser.add_argument('--optimizer', choices=supported.optimizers) + parser.add_argument('--lr', type=float) + parser.add_argument('--weight_decay', type=float) + parser.add_argument('--max_grad_norm', type=float) + parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={}) + + # Scheduler + parser.add_argument('--scheduler', choices=supported.schedulers) + parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={}) + parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val') + parser.add_argument('--scheduler_metric_name') + + # Evaluation + parser.add_argument('--evaluate_all_splits', type=parse_bool, const=True, nargs='?', default=True) + parser.add_argument('--eval_splits', nargs='+', default=[]) + parser.add_argument('--eval_only', type=parse_bool, const=True, nargs='?', default=False) + parser.add_argument('--eval_epoch', default=None, type=int) + + # Misc + parser.add_argument('--device', type=int, default=0) + parser.add_argument('--seed', type=int, default=0) + parser.add_argument('--log_dir', default='./logs') + parser.add_argument('--log_every', default=50, type=int) + parser.add_argument('--save_step', type=int) + parser.add_argument('--save_best', type=parse_bool, const=True, nargs='?', default=True) + parser.add_argument('--save_last', type=parse_bool, const=True, nargs='?', default=True) + parser.add_argument('--no_group_logging', type=parse_bool, const=True, nargs='?') + parser.add_argument('--use_wandb', type=parse_bool, const=True, nargs='?', default=False) + parser.add_argument('--progress_bar', type=parse_bool, const=True, nargs='?', default=False) + parser.add_argument('--resume', type=parse_bool, const=True, nargs='?', default=False) + + config = parser.parse_args() + config = populate_defaults(config) + + # set device + config.device = torch.device("cuda:" + str(config.device)) if torch.cuda.is_available() else torch.device("cpu") + + ## Initialize logs + if os.path.exists(config.log_dir) and config.resume: + resume=True + mode='a' + elif os.path.exists(config.log_dir) and config.eval_only: + resume=False + mode='a' + else: + resume=False + mode='w' + + if not os.path.exists(config.log_dir): + os.makedirs(config.log_dir) + logger = Logger(os.path.join(config.log_dir, 'log.txt'), mode) + + # Record config + log_config(config, logger) + + # Set random seed + set_seed(config.seed) + + # Data + full_dataset = supported.datasets[config.dataset]( + root_dir=config.root_dir, + download=config.download, + split_scheme=config.split_scheme, + **config.dataset_kwargs) + + # To implement data augmentation (i.e., have different transforms + # at training time vs. test time), modify these two lines: + train_transform = initialize_transform( + transform_name=config.train_transform, + config=config, + dataset=full_dataset) + eval_transform = initialize_transform( + transform_name=config.eval_transform, + config=config, + dataset=full_dataset) + + train_grouper = CombinatorialGrouper( + dataset=full_dataset, + groupby_fields=config.groupby_fields) + + datasets = defaultdict(dict) + for split in full_dataset.split_dict.keys(): + if split=='train': + transform = train_transform + verbose = True + elif split == 'val': + transform = eval_transform + verbose = True + else: + transform = eval_transform + verbose = False + # Get subset + datasets[split]['dataset'] = full_dataset.get_subset( + split, + frac=config.frac, + transform=transform) + + if split == 'train': + datasets[split]['loader'] = get_train_loader( + loader=config.train_loader, + dataset=datasets[split]['dataset'], + batch_size=config.batch_size, + uniform_over_groups=config.uniform_over_groups, + grouper=train_grouper, + distinct_groups=config.distinct_groups, + n_groups_per_batch=config.n_groups_per_batch, + **config.loader_kwargs) + else: + datasets[split]['loader'] = get_eval_loader( + loader=config.eval_loader, + dataset=datasets[split]['dataset'], + grouper=train_grouper, + batch_size=config.batch_size, + **config.loader_kwargs) + + # Set fields + datasets[split]['split'] = split + datasets[split]['name'] = full_dataset.split_names[split] + datasets[split]['verbose'] = verbose + # Loggers + # Loggers + datasets[split]['eval_logger'] = BatchLogger( + os.path.join(config.log_dir, f'{split}_eval.csv'), mode=mode, use_wandb=(config.use_wandb and verbose)) + datasets[split]['algo_logger'] = BatchLogger( + os.path.join(config.log_dir, f'{split}_algo.csv'), mode=mode, use_wandb=(config.use_wandb and verbose)) + + if config.use_wandb: + initialize_wandb(config) + + # Logging dataset info + if config.no_group_logging and full_dataset.is_classification and full_dataset.y_size==1: + log_grouper = CombinatorialGrouper( + dataset=full_dataset, + groupby_fields=['y']) + elif config.no_group_logging: + log_grouper = None + else: + log_grouper = train_grouper + log_group_data(datasets, log_grouper, logger) + + ## Initialize algorithm + algorithm = initialize_algorithm( + config=config, + datasets=datasets, + train_grouper=train_grouper) + + if not config.eval_only: + ## Load saved results if resuming + resume_success = False + if resume: + save_path = os.path.join(config.log_dir, 'last_model.pth') + if not os.path.exists(save_path): + epochs = [ + int(file.split('_')[0]) + for file in os.listdir(config.log_dir) if file.endswith('.pth')] + if len(epochs) > 0: + latest_epoch = max(epochs) + save_path = os.path.join(config.log_dir, f'{latest_epoch}_model.pth') + try: + prev_epoch, best_val_metric = load(algorithm, save_path) + epoch_offset = prev_epoch + 1 + logger.write(f'Resuming from epoch {epoch_offset} with best val metric {best_val_metric}') + resume_success = True + except FileNotFoundError: + pass + + if resume_success == False: + epoch_offset=0 + best_val_metric=None + + + train( + algorithm=algorithm, + datasets=datasets, + general_logger=logger, + config=config, + epoch_offset=epoch_offset, + best_val_metric=best_val_metric) + else: + if config.eval_epoch is None: + eval_model_path = os.path.join(config.log_dir, 'best_model.pth') + else: + eval_model_path = os.path.join(config.log_dir, f'{config.eval_epoch}_model.pth') + best_epoch, best_val_metric = load(algorithm, eval_model_path) + if config.eval_epoch is None: + epoch = best_epoch + else: + epoch = config.eval_epoch + evaluate( + algorithm=algorithm, + datasets=datasets, + epoch=epoch, + general_logger=logger, + config=config) + + logger.close() + for split in datasets: + datasets[split]['eval_logger'].close() + datasets[split]['algo_logger'].close() + +if __name__=='__main__': + main() diff --git a/sandbox_data.ipynb b/sandbox_data.ipynb index 681c5ec2..a348d6ea 100644 --- a/sandbox_data.ipynb +++ b/sandbox_data.ipynb @@ -18,7 +18,7 @@ " - [x] datasets/encodetfbs_dataset.py\n", " - common\n", " - metrics\n", - " - [ ] all_metrics.py\n", + " - [x] all_metrics.py\n", " - data_loaders.py\n", " - grouper.py\n", " - [ ] utils.py ( threshold_at_recall() )" @@ -30,9 +30,12 @@ "source": [ "# TODOs\n", "\n", - "- change evaluation metric\n", + "- [ ] change evaluation/validation metric\n", + " - \n", "\n", - "- change sequence length of model\n", + "- [ ] Citation/license for wilds/datasets/encodetfbs_dataset.py\n", + "\n", + "- (optional) change sequence length of model\n", " - examples/configs/model.py\n", " - examples/models/CNN_genome.py" ] diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py index 85506eab..3330243e 100644 --- a/wilds/common/metrics/all_metrics.py +++ b/wilds/common/metrics/all_metrics.py @@ -81,7 +81,7 @@ def worst(self, metrics): return minimum(metrics) class AveragePrecision(Metric): - def __init__(self, prediction_fn=logits_to_pred, name=None, average='weighted'): + def __init__(self, prediction_fn=logits_to_pred, name=None, average='macro'): self.prediction_fn = prediction_fn if name is None: name = f'avgprec' From 82bbe483f28fda38666ff4ae5999e5e77b15f962 Mon Sep 17 00:00:00 2001 From: aikanor Date: Wed, 10 Feb 2021 01:11:25 -0800 Subject: [PATCH 019/244] integration 13/ --- examples/models/CNN_genome.py | 2 +- examples/sbox_run_expt.ipynb | 115 ++++++++------------------- sandbox_data.ipynb | 16 ++-- wilds/datasets/encodetfbs_dataset.py | 23 +++--- 4 files changed, 53 insertions(+), 103 deletions(-) diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py index b0743960..cc464ef0 100644 --- a/examples/models/CNN_genome.py +++ b/examples/models/CNN_genome.py @@ -50,4 +50,4 @@ def forward(self, s): prediction = self.fc3(s) - return s#, conv_out + return s #, conv_out diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb index 8b79c235..f6c6246e 100644 --- a/examples/sbox_run_expt.ipynb +++ b/examples/sbox_run_expt.ipynb @@ -42,7 +42,7 @@ { "data": { "text/plain": [ - "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" + "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" ] }, "execution_count": 2, @@ -145,7 +145,7 @@ "outputs": [], "source": [ "argstr_camelyon = \"--dataset camelyon17 --algorithm ERM --root_dir data\"\n", - "argstr_encode = \"--dataset encode-tfbs --algorithm ERM --root_dir data --download\"" + "argstr_encode = \"--dataset encode-tfbs --algorithm ERM --root_dir data\"" ] }, { @@ -160,81 +160,21 @@ "config_encode = populate_defaults(config_encode)" ] }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "#config = config_camelyon\n", - "config = config_encode" - ] - }, { "cell_type": "code", "execution_count": 6, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "config.download" + "config = config_camelyon\n", + "#config = config_encode" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Downloading dataset to data/encode-tfbs_v1.0...\n", - "You can also download the dataset manually at https://wilds.stanford.edu/downloads.\n", - "Downloading https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/ to data/encode-tfbs_v1.0/archive.tar.gz\n", - "\n", - "data/encode-tfbs_v1.0/archive.tar.gz may be corrupted. Please try deleting it and rerunning this command.\n", - "\n", - "Exception: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Exception ignored in: \n", - "Traceback (most recent call last):\n", - " File \"/users/abalsubr/anaconda2/envs/wilds1/lib/python3.8/site-packages/tqdm/std.py\", line 1134, in __del__\n", - " self.close()\n", - " File \"/users/abalsubr/anaconda2/envs/wilds1/lib/python3.8/site-packages/tqdm/notebook.py\", line 283, in close\n", - " self.disp(bar_style='success')\n", - "AttributeError: 'tqdm' object has no attribute 'disp'\n" - ] - }, - { - "ename": "FileNotFoundError", - "evalue": "[Errno 2] No such file or directory: 'data/encode-tfbs_v1.0/sequence.npz'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;31m# Data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m full_dataset = supported.datasets[config.dataset](\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0mroot_dir\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mroot_dir\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mdownload\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdownload\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/wilds/datasets/encodetfbs_dataset.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, root_dir, download, split_scheme)\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0;31m# Load sequence and DNase features\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 50\u001b[0m \u001b[0msequence_filename\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_data_dir\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'sequence.npz'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 51\u001b[0;31m \u001b[0mseq_arr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msequence_filename\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 52\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_seq_bp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mchrom\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mseq_arr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/numpy/lib/npyio.py\u001b[0m in \u001b[0;36mload\u001b[0;34m(file, mmap_mode, allow_pickle, fix_imports, encoding)\u001b[0m\n\u001b[1;32m 415\u001b[0m \u001b[0mown_fid\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 416\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 417\u001b[0;31m \u001b[0mfid\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mstack\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menter_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos_fspath\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfile\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"rb\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 418\u001b[0m \u001b[0mown_fid\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 419\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'data/encode-tfbs_v1.0/sequence.npz'" - ] - } - ], + "outputs": [], "source": [ "# set device\n", "config.device = torch.device(\"cuda:\" + str(config.device)) if torch.cuda.is_available() else torch.device(\"cpu\")\n", @@ -279,6 +219,26 @@ " dataset=full_dataset)" ] }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "config.download" + ] + }, { "cell_type": "code", "execution_count": 9, @@ -286,7 +246,7 @@ "outputs": [], "source": [ "import copy\n", - "full_dataset_camelyon17 = copy.deepcopy(full_dataset)\n" + "full_dataset_camelyon17 = copy.deepcopy(full_dataset)" ] }, { @@ -311,7 +271,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -329,7 +289,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -338,7 +298,7 @@ "1" ] }, - "execution_count": 12, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -487,18 +447,9 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "57.8772239685\n", - "66.8270189762\n" - ] - } - ], + "outputs": [], "source": [ "import numpy as np, pandas as pd, os, time, torch, torchvision\n", "data_dir = '/oak/stanford/groups/akundaje/abalsubr/DREAM/wilds/codalab_archive/'\n", diff --git a/sandbox_data.ipynb b/sandbox_data.ipynb index a348d6ea..15e35cc9 100644 --- a/sandbox_data.ipynb +++ b/sandbox_data.ipynb @@ -9,7 +9,7 @@ " - configs\n", " - [x] supported.py\n", " - [x] model.py\n", - " - [ ] datasets.py\n", + " - [x] datasets.py\n", " - models\n", " - [x] CNN_genome.py\n", " - train.py\n", @@ -21,7 +21,7 @@ " - [x] all_metrics.py\n", " - data_loaders.py\n", " - grouper.py\n", - " - [ ] utils.py ( threshold_at_recall() )" + " - [x] utils.py ( threshold_at_recall() )" ] }, { @@ -30,14 +30,12 @@ "source": [ "# TODOs\n", "\n", - "- [ ] change evaluation/validation metric\n", - " - \n", - "\n", - "- [ ] Citation/license for wilds/datasets/encodetfbs_dataset.py\n", - "\n", + "- change evaluation/validation metric\n", + " - [ ] examples/configs/datasets.py\n", + "- Citation/license for wilds/datasets/encodetfbs_dataset.py\n", "- (optional) change sequence length of model\n", - " - examples/configs/model.py\n", - " - examples/models/CNN_genome.py" + " - [ ] examples/configs/model.py\n", + " - [ ] examples/models/CNN_genome.py" ] }, { diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 23f0b1d7..56f8c2f9 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -62,13 +62,14 @@ def __init__(self, root_dir, download, split_scheme): self._dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom] # Read in metadata dataframe from training+validation data - train_chr = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.train.labels.tsv.gz'.format(self._transcription_factor)), sep='\t') - val_chr = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.val.labels.tsv.gz'.format(self._transcription_factor)), sep='\t') - training_df = train_chr[np.isin(train_chr['chr'], self._tr_chrs)] - val_df = val_chr[np.isin(val_chr['chr'], self._te_chrs)] + train_regions_labeled = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.train.labels.tsv.gz'.format(self._transcription_factor)), sep='\t') + val_regions_labeled = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.val.labels.tsv.gz'.format(self._transcription_factor)), sep='\t') + training_df = train_regions_labeled[np.isin(train_regions_labeled['chr'], self._tr_chrs)] + val_df = val_regions_labeled[np.isin(val_regions_labeled['chr'], self._te_chrs)] all_df = pd.concat([training_df, val_df]) - # Filter by start/stop coordinate if needed + # Filter by start/stop coordinate if needed + # (TODO: remove for final version) filter_msk = all_df['start'] >= 0 filter_msk = all_df['start']%1000 == 0 all_df = all_df[filter_msk] @@ -111,18 +112,18 @@ def __init__(self, root_dir, download, split_scheme): 'test': 'Test', 'val-ood': 'Validation (OOD)', } - train_chr_mask = np.isin(self._metadata_df['chr'], self._tr_chrs) - val_chr_mask = np.isin(self._metadata_df['chr'], self._te_chrs) + train_regions_mask = np.isin(self._metadata_df['chr'], self._tr_chrs) + val_regions_mask = np.isin(self._metadata_df['chr'], self._te_chrs) train_celltype_mask = np.isin(self._metadata_df['celltype'], self._train_celltypes) val_celltype_mask = np.isin(self._metadata_df['celltype'], self._val_celltype) test_celltype_mask = np.isin(self._metadata_df['celltype'], self._test_celltype) split_array = -1*np.ones(self._metadata_df.shape[0]).astype(int) - split_array[np.logical_and(train_chr_mask, train_celltype_mask)] = self._split_dict['train'] - split_array[np.logical_and(val_chr_mask, test_celltype_mask)] = self._split_dict['test'] + split_array[np.logical_and(train_regions_mask, train_celltype_mask)] = self._split_dict['train'] + split_array[np.logical_and(val_regions_mask, test_celltype_mask)] = self._split_dict['test'] # Validate using test chr, either using a designated validation cell line ('val-ood') or a training cell line ('val-id') - split_array[np.logical_and(val_chr_mask, val_celltype_mask)] = self._split_dict['val-ood'] - split_array[np.logical_and(val_chr_mask, train_celltype_mask)] = self._split_dict['val-id'] + split_array[np.logical_and(val_regions_mask, val_celltype_mask)] = self._split_dict['val-ood'] + split_array[np.logical_and(val_regions_mask, train_celltype_mask)] = self._split_dict['val-id'] if self._split_scheme=='standard': self._metadata_df['split'] = split_array self._split_array = split_array From a7e6a5eebfe7a7b0606c3c8746cddc73a6389889 Mon Sep 17 00:00:00 2001 From: aikanor Date: Wed, 10 Feb 2021 04:03:51 -0800 Subject: [PATCH 020/244] integration 14/ --- examples/sbox_run_expt.ipynb | 86 ++++++++++++++++++++++++---- sandbox_data.ipynb | 1 + wilds/datasets/encodetfbs_dataset.py | 2 +- 3 files changed, 77 insertions(+), 12 deletions(-) diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb index f6c6246e..8684df87 100644 --- a/examples/sbox_run_expt.ipynb +++ b/examples/sbox_run_expt.ipynb @@ -42,7 +42,7 @@ { "data": { "text/plain": [ - "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" + "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" ] }, "execution_count": 2, @@ -145,7 +145,7 @@ "outputs": [], "source": [ "argstr_camelyon = \"--dataset camelyon17 --algorithm ERM --root_dir data\"\n", - "argstr_encode = \"--dataset encode-tfbs --algorithm ERM --root_dir data\"" + "argstr_encode = \"--dataset encode-tfbs --algorithm ERM --root_dir data\"\n" ] }, { @@ -162,19 +162,83 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "config = config_camelyon\n", - "#config = config_encode" + "config = config_encode\n" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset: encode-tfbs\n", + "Algorithm: ERM\n", + "Root dir: data\n", + "Split scheme: official\n", + "Dataset kwargs: {}\n", + "Download: False\n", + "Frac: 1.0\n", + "Loader kwargs: {'num_workers': 4, 'pin_memory': True}\n", + "Train loader: standard\n", + "Uniform over groups: False\n", + "Distinct groups: None\n", + "N groups per batch: 2\n", + "Batch size: 128\n", + "Eval loader: standard\n", + "Model: beagle\n", + "Model kwargs: {'pretrained': False}\n", + "Train transform: None\n", + "Eval transform: None\n", + "Target resolution: None\n", + "Resize scale: None\n", + "Max token length: None\n", + "Loss function: cross_entropy\n", + "Groupby fields: ['celltype']\n", + "Group dro step size: None\n", + "Coral penalty weight: None\n", + "Irm lambda: None\n", + "Irm penalty anneal iters: None\n", + "Algo log metric: accuracy\n", + "Val metric: acc_avg\n", + "Val metric decreasing: False\n", + "N epochs: 1\n", + "Optimizer: Adam\n", + "Lr: 0.001\n", + "Weight decay: 0.01\n", + "Max grad norm: None\n", + "Optimizer kwargs: {'momentum': 0.9}\n", + "Scheduler: None\n", + "Scheduler kwargs: {}\n", + "Scheduler metric split: val\n", + "Scheduler metric name: None\n", + "Evaluate all splits: True\n", + "Eval splits: []\n", + "Eval only: False\n", + "Eval epoch: None\n", + "Device: cuda:0\n", + "Seed: 0\n", + "Log dir: ./logs\n", + "Log every: 50\n", + "Save step: None\n", + "Save best: True\n", + "Save last: True\n", + "No group logging: False\n", + "Use wandb: False\n", + "Progress bar: False\n", + "Resume: False\n", + "\n", + "asdf data/encode-tfbs_v1.0 data/encode-tfbs_v1.0/RELEASE_v1.0.txt 1 0 data/encode-tfbs_v1.0/RELEASE_v1.0.txt\n" + ] + } + ], "source": [ "# set device\n", "config.device = torch.device(\"cuda:\" + str(config.device)) if torch.cuda.is_available() else torch.device(\"cpu\")\n", @@ -195,7 +259,7 @@ "logger = Logger(os.path.join(config.log_dir, 'log.txt'), mode)\n", "\n", "# Record config\n", - "# log_config(config, logger)\n", + "log_config(config, logger)\n", "\n", "# Set random seed\n", "set_seed(config.seed)\n", @@ -221,22 +285,22 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "False" + "device(type='cuda', index=0)" ] }, - "execution_count": 7, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "config.download" + "config.device" ] }, { diff --git a/sandbox_data.ipynb b/sandbox_data.ipynb index 15e35cc9..c465e0ab 100644 --- a/sandbox_data.ipynb +++ b/sandbox_data.ipynb @@ -32,6 +32,7 @@ "\n", "- change evaluation/validation metric\n", " - [ ] examples/configs/datasets.py\n", + "- Add `RELEASE_v1.0.txt` to codalab archive\n", "- Citation/license for wilds/datasets/encodetfbs_dataset.py\n", "- (optional) change sequence length of model\n", " - [ ] examples/configs/model.py\n", diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 56f8c2f9..dc76a366 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -24,7 +24,7 @@ class EncodeTFBSDataset(WILDSDataset): https://www.synapse.org/#!Synapse:syn6131484 """ - def __init__(self, root_dir, download, split_scheme): + def __init__(self, root_dir='data', download=False, split_scheme='official'): self._dataset_name = 'encode-tfbs' self._version = '1.0' self._download_url = 'https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/' From b2d4668a52ad89f2935993526197e866cf1ce26d Mon Sep 17 00:00:00 2001 From: aikanor Date: Wed, 10 Feb 2021 08:56:19 -0800 Subject: [PATCH 021/244] integration 14/ --- examples/sbox_run_expt.ipynb | 90 ++++++++++++++++++++++++++-- wilds/datasets/encodetfbs_dataset.py | 8 ++- 2 files changed, 91 insertions(+), 7 deletions(-) diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb index 8684df87..b5581ee8 100644 --- a/examples/sbox_run_expt.ipynb +++ b/examples/sbox_run_expt.ipynb @@ -42,7 +42,7 @@ { "data": { "text/plain": [ - "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" + "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" ] }, "execution_count": 2, @@ -172,7 +172,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -235,7 +235,87 @@ "Progress bar: False\n", "Resume: False\n", "\n", - "asdf data/encode-tfbs_v1.0 data/encode-tfbs_v1.0/RELEASE_v1.0.txt 1 0 data/encode-tfbs_v1.0/RELEASE_v1.0.txt\n" + "chr1 3.7313027381896973\n", + "chr2 7.379143953323364\n", + "chr3 10.349414587020874\n", + "chr4 13.229614734649658\n", + "chr5 15.956977605819702\n", + "chr6 18.53829526901245\n", + "chr7 20.938751935958862\n", + "chr8 23.146727561950684\n", + "chr9 25.268455028533936\n", + "chr10 27.31314730644226\n", + "chr11 29.348772525787354\n", + "chr12 31.367613554000854\n", + "chr13 33.09688472747803\n", + "chr14 34.706626892089844\n", + "chr15 36.250834941864014\n", + "chr16 37.6129195690155\n", + "chr17 38.84043622016907\n", + "chr18 40.02457594871521\n", + "chr19 40.91551661491394\n", + "chr20 41.871009349823\n", + "chr21 42.5933620929718\n", + "chr22 43.36084580421448\n", + "chrX 45.713836431503296\n", + "H1-hESC 65.68310618400574\n", + "HCT116 85.14427304267883\n", + "HeLa-S3 105.4215178489685\n", + "HepG2 125.414067029953\n", + "K562 145.60503768920898\n", + "A549 165.76467108726501\n", + "GM12878 185.90100407600403\n", + " chr start stop A549 GM12878 H1-hESC HCT116 HeLa-S3 \\\n", + "0 chr10 600 800 U U U U U \n", + "1 chr10 650 850 U U U U U \n", + "2 chr10 700 900 U U U U U \n", + "3 chr10 750 950 U U U U U \n", + "4 chr10 800 1000 U U U U U \n", + "... ... ... ... ... ... ... ... ... \n", + "51676731 chrX 155269750 155269950 U U U U U \n", + "51676732 chrX 155269800 155270000 U U U U U \n", + "51676733 chrX 155269850 155270050 U U U U U \n", + "51676734 chrX 155269900 155270100 U U U U U \n", + "51676735 chrX 155269950 155270150 U U U U U \n", + "\n", + " HepG2 K562 \n", + "0 U U \n", + "1 U U \n", + "2 U U \n", + "3 U U \n", + "4 U U \n", + "... ... ... \n", + "51676731 U U \n", + "51676732 U U \n", + "51676733 U U \n", + "51676734 U U \n", + "51676735 U U \n", + "\n", + "[51676736 rows x 10 columns] 252.31056427955627\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/users/abalsubr/wilds/wilds/datasets/encodetfbs_dataset.py:85: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " tc_chr['celltype'] = ct\n" + ] + }, + { + "ename": "ValueError", + "evalue": "Split scheme official not recognized", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;31m# Data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m full_dataset = supported.datasets[config.dataset](\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0mroot_dir\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mroot_dir\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mdownload\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdownload\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/wilds/datasets/encodetfbs_dataset.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, root_dir, download, split_scheme)\u001b[0m\n\u001b[1;32m 133\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_split_array\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msplit_array\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 135\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'Split scheme {self._split_scheme} not recognized'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 136\u001b[0m self._eval_grouper = CombinatorialGrouper(\n\u001b[1;32m 137\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mValueError\u001b[0m: Split scheme official not recognized" ] } ], @@ -285,7 +365,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -294,7 +374,7 @@ "device(type='cuda', index=0)" ] }, - "execution_count": 9, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index dc76a366..327ffb3e 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -1,4 +1,4 @@ -import os +import os, time import torch import pandas as pd import numpy as np @@ -25,6 +25,7 @@ class EncodeTFBSDataset(WILDSDataset): """ def __init__(self, root_dir='data', download=False, split_scheme='official'): + itime = time.time() self._dataset_name = 'encode-tfbs' self._version = '1.0' self._download_url = 'https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/' @@ -52,6 +53,7 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): self._seq_bp = {} for chrom in seq_arr: self._seq_bp[chrom] = seq_arr[chrom] + print(chrom, time.time() - itime) self._dnase_allcelltypes = {} for ct in self._all_celltypes: @@ -60,6 +62,7 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): self._dnase_allcelltypes[ct] = {} for chrom in self._seq_bp: self._dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom] + print(ct, time.time() - itime) # Read in metadata dataframe from training+validation data train_regions_labeled = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.train.labels.tsv.gz'.format(self._transcription_factor)), sep='\t') @@ -67,6 +70,7 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): training_df = train_regions_labeled[np.isin(train_regions_labeled['chr'], self._tr_chrs)] val_df = val_regions_labeled[np.isin(val_regions_labeled['chr'], self._te_chrs)] all_df = pd.concat([training_df, val_df]) + print(train_regions_labeled, time.time() - itime) # Filter by start/stop coordinate if needed # (TODO: remove for final version) @@ -150,7 +154,7 @@ def get_input(self, idx): interval_end = this_metadata['stop'] + flank_size dnase_this = _dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end] seq_this = _seq_bp[this_metadata['chr']][interval_start:interval_end] - return np.column_stack([seq_this, dnase_this]) + return torch.tensor(np.column_stack([seq_this, dnase_this])) def eval(self, y_pred, y_true, metadata): return self.standard_group_eval( From c348587a208d85d53fc12e85d795469991f3f4af Mon Sep 17 00:00:00 2001 From: aikanor Date: Wed, 10 Feb 2021 19:45:41 -0800 Subject: [PATCH 022/244] integration 15/ (refactor encodetfbs_dataset) --- examples/sbox_run_expt.ipynb | 2133 +++++++++++++++++++------- wilds/datasets/encodetfbs_dataset.py | 102 +- 2 files changed, 1664 insertions(+), 571 deletions(-) diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb index b5581ee8..56a9f6a2 100644 --- a/examples/sbox_run_expt.ipynb +++ b/examples/sbox_run_expt.ipynb @@ -4,19 +4,21 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# run_expt.py contents" + "# run_expt.py contents\n", + "\n", + "## 1) Preamble" ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "import os, csv\n", "import time\n", "import argparse\n", - "import pandas as pd\n", + "import numpy as np, pandas as pd\n", "import torch\n", "import torch.nn as nn\n", "import torchvision\n", @@ -42,7 +44,7 @@ { "data": { "text/plain": [ - "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" + "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" ] }, "execution_count": 2, @@ -145,41 +147,32 @@ "outputs": [], "source": [ "argstr_camelyon = \"--dataset camelyon17 --algorithm ERM --root_dir data\"\n", - "argstr_encode = \"--dataset encode-tfbs --algorithm ERM --root_dir data\"\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ "config_camelyon = parser.parse_args(argstr_camelyon.split())\n", - "config_encode = parser.parse_args(argstr_encode.split())\n", "config_camelyon = populate_defaults(config_camelyon)\n", - "config_encode = populate_defaults(config_encode)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ + "\n", + "argstr_encode = \"--dataset encode-tfbs --algorithm ERM --root_dir data\"\n", + "config_encode = parser.parse_args(argstr_encode.split())\n", + "config_encode = populate_defaults(config_encode)\n", + "\n", "config = config_camelyon\n", - "config = config_encode\n" + "#config = config_encode" ] }, { "cell_type": "code", "execution_count": 6, - "metadata": {}, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Dataset: encode-tfbs\n", + "Dataset: camelyon17\n", "Algorithm: ERM\n", "Root dir: data\n", "Split scheme: official\n", @@ -191,26 +184,26 @@ "Uniform over groups: False\n", "Distinct groups: None\n", "N groups per batch: 2\n", - "Batch size: 128\n", + "Batch size: 32\n", "Eval loader: standard\n", - "Model: beagle\n", + "Model: densenet121\n", "Model kwargs: {'pretrained': False}\n", - "Train transform: None\n", - "Eval transform: None\n", - "Target resolution: None\n", + "Train transform: image_base\n", + "Eval transform: image_base\n", + "Target resolution: (224, 224)\n", "Resize scale: None\n", "Max token length: None\n", "Loss function: cross_entropy\n", - "Groupby fields: ['celltype']\n", + "Groupby fields: ['hospital']\n", "Group dro step size: None\n", - "Coral penalty weight: None\n", - "Irm lambda: None\n", + "Coral penalty weight: 0.1\n", + "Irm lambda: 1.0\n", "Irm penalty anneal iters: None\n", "Algo log metric: accuracy\n", "Val metric: acc_avg\n", "Val metric decreasing: False\n", - "N epochs: 1\n", - "Optimizer: Adam\n", + "N epochs: 5\n", + "Optimizer: SGD\n", "Lr: 0.001\n", "Weight decay: 0.01\n", "Max grad norm: None\n", @@ -234,88 +227,7 @@ "Use wandb: False\n", "Progress bar: False\n", "Resume: False\n", - "\n", - "chr1 3.7313027381896973\n", - "chr2 7.379143953323364\n", - "chr3 10.349414587020874\n", - "chr4 13.229614734649658\n", - "chr5 15.956977605819702\n", - "chr6 18.53829526901245\n", - "chr7 20.938751935958862\n", - "chr8 23.146727561950684\n", - "chr9 25.268455028533936\n", - "chr10 27.31314730644226\n", - "chr11 29.348772525787354\n", - "chr12 31.367613554000854\n", - "chr13 33.09688472747803\n", - "chr14 34.706626892089844\n", - "chr15 36.250834941864014\n", - "chr16 37.6129195690155\n", - "chr17 38.84043622016907\n", - "chr18 40.02457594871521\n", - "chr19 40.91551661491394\n", - "chr20 41.871009349823\n", - "chr21 42.5933620929718\n", - "chr22 43.36084580421448\n", - "chrX 45.713836431503296\n", - "H1-hESC 65.68310618400574\n", - "HCT116 85.14427304267883\n", - "HeLa-S3 105.4215178489685\n", - "HepG2 125.414067029953\n", - "K562 145.60503768920898\n", - "A549 165.76467108726501\n", - "GM12878 185.90100407600403\n", - " chr start stop A549 GM12878 H1-hESC HCT116 HeLa-S3 \\\n", - "0 chr10 600 800 U U U U U \n", - "1 chr10 650 850 U U U U U \n", - "2 chr10 700 900 U U U U U \n", - "3 chr10 750 950 U U U U U \n", - "4 chr10 800 1000 U U U U U \n", - "... ... ... ... ... ... ... ... ... \n", - "51676731 chrX 155269750 155269950 U U U U U \n", - "51676732 chrX 155269800 155270000 U U U U U \n", - "51676733 chrX 155269850 155270050 U U U U U \n", - "51676734 chrX 155269900 155270100 U U U U U \n", - "51676735 chrX 155269950 155270150 U U U U U \n", - "\n", - " HepG2 K562 \n", - "0 U U \n", - "1 U U \n", - "2 U U \n", - "3 U U \n", - "4 U U \n", - "... ... ... \n", - "51676731 U U \n", - "51676732 U U \n", - "51676733 U U \n", - "51676734 U U \n", - "51676735 U U \n", - "\n", - "[51676736 rows x 10 columns] 252.31056427955627\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/users/abalsubr/wilds/wilds/datasets/encodetfbs_dataset.py:85: SettingWithCopyWarning: \n", - "A value is trying to be set on a copy of a slice from a DataFrame.\n", - "Try using .loc[row_indexer,col_indexer] = value instead\n", - "\n", - "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", - " tc_chr['celltype'] = ct\n" - ] - }, - { - "ename": "ValueError", - "evalue": "Split scheme official not recognized", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;31m# Data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m full_dataset = supported.datasets[config.dataset](\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0mroot_dir\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mroot_dir\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mdownload\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdownload\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/wilds/datasets/encodetfbs_dataset.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, root_dir, download, split_scheme)\u001b[0m\n\u001b[1;32m 133\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_split_array\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msplit_array\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 135\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'Split scheme {self._split_scheme} not recognized'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 136\u001b[0m self._eval_grouper = CombinatorialGrouper(\n\u001b[1;32m 137\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mValueError\u001b[0m: Split scheme official not recognized" + "\n" ] } ], @@ -364,93 +276,1549 @@ ] }, { - "cell_type": "code", - "execution_count": 7, + "cell_type": "markdown", "metadata": {}, + "source": [ + "## 2) Initialize dataset object" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, "outputs": [ { - "data": { - "text/plain": [ - "device(type='cuda', index=0)" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "chr2 3.817250967025757\n", + "chr9 6.033524990081787\n", + "chr11 8.150986433029175\n", + "chr1 12.036555290222168\n", + "chr8 14.306443929672241\n", + "chr21 15.043241739273071\n", + "H1-hESC 21.61008930206299\n", + "HCT116 28.000329971313477\n", + "HeLa-S3 34.6184778213501\n", + "HepG2 41.089255809783936\n", + "K562 47.70136523246765\n", + "A549 54.22390341758728\n", + "GM12878 60.65142226219177\n", + " chr start stop A549 GM12878 H1-hESC HCT116 HeLa-S3 \\\n", + "0 chr10 600 800 U U U U U \n", + "1 chr10 650 850 U U U U U \n", + "2 chr10 700 900 U U U U U \n", + "3 chr10 750 950 U U U U U \n", + "4 chr10 800 1000 U U U U U \n", + "... ... ... ... ... ... ... ... ... \n", + "51676731 chrX 155269750 155269950 U U U U U \n", + "51676732 chrX 155269800 155270000 U U U U U \n", + "51676733 chrX 155269850 155270050 U U U U U \n", + "51676734 chrX 155269900 155270100 U U U U U \n", + "51676735 chrX 155269950 155270150 U U U U U \n", + "\n", + " HepG2 K562 \n", + "0 U U \n", + "1 U U \n", + "2 U U \n", + "3 U U \n", + "4 U U \n", + "... ... ... \n", + "51676731 U U \n", + "51676732 U U \n", + "51676733 U U \n", + "51676734 U U \n", + "51676735 U U \n", + "\n", + "[51676736 rows x 10 columns] 130.07371044158936\n" + ] } ], "source": [ - "config.device" + "import os, time\n", + "import torch\n", + "import pandas as pd\n", + "import numpy as np\n", + "from wilds.datasets.wilds_dataset import WILDSDataset\n", + "from wilds.common.grouper import CombinatorialGrouper\n", + "from wilds.common.metrics.all_metrics import Accuracy\n", + "\n", + "root_dir='data'\n", + "download=False\n", + "split_scheme='official'\n", + "\n", + "itime = time.time()\n", + "_dataset_name = 'encode-tfbs'\n", + "_version = '1.0'\n", + "_download_url = 'https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/'\n", + "_data_dir = 'data/encode-tfbs_v1.0'\n", + "_y_size = 1\n", + "_n_classes = 2\n", + "\n", + "# _train_chroms = ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX']\n", + "_train_chroms = ['chr2', 'chr9', 'chr11']\n", + "_test_chroms = ['chr1', 'chr8', 'chr21']\n", + "_transcription_factor = 'MAX'\n", + "_train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", + "_val_celltype = ['A549']\n", + "_test_celltype = ['GM12878']\n", + "_all_chroms = _train_chroms + _test_chroms\n", + "_all_celltypes = _train_celltypes + _val_celltype + _test_celltype\n", + "\n", + "_metadata_map = {}\n", + "_metadata_map['chr'] = _all_chroms #['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']\n", + "_metadata_map['celltype'] = _all_celltypes\n", + "\n", + "# Get the splits\n", + "if split_scheme=='official':\n", + " split_scheme = 'standard'\n", + "\n", + "_split_scheme = split_scheme\n", + "_split_dict = {\n", + " 'train': 0,\n", + " 'id_val': 1,\n", + " 'test': 2,\n", + " 'val': 3\n", + "}\n", + "_split_names = {\n", + " 'train': 'Train',\n", + " 'id_val': 'Validation (ID)',\n", + " 'test': 'Test',\n", + " 'val': 'Validation (OOD)',\n", + "}\n", + "\n", + "# Load sequence and DNase features\n", + "sequence_filename = os.path.join(_data_dir, 'sequence.npz')\n", + "seq_arr = np.load(sequence_filename)\n", + "_seq_bp = {}\n", + "for chrom in _all_chroms: #seq_arr:\n", + " _seq_bp[chrom] = seq_arr[chrom]\n", + " print(chrom, time.time() - itime)\n", + "\n", + "_dnase_allcelltypes = {}\n", + "for ct in _all_celltypes:\n", + " dnase_filename = os.path.join(_data_dir, '{}_dnase.npz'.format(ct))\n", + " dnase_npz_contents = np.load(dnase_filename)\n", + " _dnase_allcelltypes[ct] = {}\n", + " for chrom in _all_chroms: #_seq_bp:\n", + " _dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom]\n", + " print(ct, time.time() - itime)\n", + "\n", + "# Read in metadata dataframe from training+validation data\n", + "train_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.train.labels.tsv.gz'.format(_transcription_factor)), sep='\\t')\n", + "val_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.val.labels.tsv.gz'.format(_transcription_factor)), sep='\\t')\n", + "training_df = train_regions_labeled[np.isin(train_regions_labeled['chr'], _train_chroms)]\n", + "val_df = val_regions_labeled[np.isin(val_regions_labeled['chr'], _test_chroms)]\n", + "all_df = pd.concat([training_df, val_df])\n", + "\n", + "# Filter by start/stop coordinate if needed (TODO: remove for final version)\n", + "filter_msk = all_df['start'] >= 0\n", + "filter_msk = all_df['start']%1000 == 0\n", + "all_df = all_df[filter_msk]\n", + "\n", + "pd_list = []\n", + "for ct in _all_celltypes:\n", + " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", + " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", + " tc_chr.insert(len(tc_chr.columns), 'celltype', ct)\n", + " pd_list.append(tc_chr)\n", + "metadata_df = pd.concat(pd_list)" + ] + }, + { + "cell_type": "code", + "execution_count": 131, + "metadata": {}, + "outputs": [], + "source": [ + "# Get the y values, and remove ambiguous labels by default.\n", + "y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", + "non_ambig_mask = (y_array != -1)\n", + "metadata_df['y'] = y_array\n", + "_metadata_df = metadata_df[non_ambig_mask]\n", + "\n", + "train_regions_mask = np.isin(_metadata_df['chr'], _train_chroms)\n", + "val_regions_mask = np.isin(_metadata_df['chr'], _test_chroms)\n", + "train_celltype_mask = np.isin(_metadata_df['celltype'], _train_celltypes)\n", + "val_celltype_mask = np.isin(_metadata_df['celltype'], _val_celltype)\n", + "test_celltype_mask = np.isin(_metadata_df['celltype'], _test_celltype)\n", + "\n", + "split_array = -1*np.ones(_metadata_df.shape[0]).astype(int)\n", + "split_array[np.logical_and(train_regions_mask, train_celltype_mask)] = _split_dict['train']\n", + "split_array[np.logical_and(val_regions_mask, test_celltype_mask)] = _split_dict['test']\n", + "# Validate using test chr, either using a designated validation cell line ('val') or a training cell line ('id_val')\n", + "split_array[np.logical_and(val_regions_mask, val_celltype_mask)] = _split_dict['val']\n", + "split_array[np.logical_and(val_regions_mask, train_celltype_mask)] = _split_dict['id_val']\n", + "\n", + "if _split_scheme=='standard':\n", + " _metadata_df.insert(len(_metadata_df.columns), 'split', split_array)\n", + "else:\n", + " raise ValueError(f'Split scheme {_split_scheme} not recognized')\n", + "\n", + "_metadata_df = _metadata_df[_metadata_df['split'] != -1]\n", + "_split_array = _metadata_df['split'].values\n", + "\n", + "chr_ints = _metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(_metadata_map['chr'])] )).values\n", + "celltype_ints = _metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(_metadata_map['celltype'])] )).values\n", + "_y_array = torch.LongTensor(np.array(_metadata_df['y']))\n", + "\n", + "_metadata_array = torch.stack(\n", + " (torch.LongTensor(chr_ints), \n", + " torch.LongTensor(celltype_ints), \n", + " _y_array),\n", + " dim=1)\n", + "_metadata_fields = ['chr', 'celltype', 'y']\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Initialize dataset object" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 138, "metadata": {}, "outputs": [], "source": [ - "import copy\n", - "full_dataset_camelyon17 = copy.deepcopy(full_dataset)" + "import os, time\n", + "import torch\n", + "import pandas as pd\n", + "import numpy as np\n", + "from wilds.datasets.wilds_dataset import WILDSDataset\n", + "from wilds.common.grouper import CombinatorialGrouper\n", + "from wilds.common.metrics.all_metrics import Accuracy\n", + "\n", + "class EncodeTFBSDataset(WILDSDataset):\n", + " \"\"\"\n", + " ENCODE-DREAM-wilds dataset of transcription factor binding sites. \n", + " This is a subset of the dataset from the ENCODE-DREAM in vivo Transcription Factor Binding Site Prediction Challenge. \n", + " \n", + " Input (x):\n", + " 1000-base-pair regions of sequence with a quantified chromatin accessibility readout.\n", + "\n", + " Label (y):\n", + " y is binary. It is 1 if the central 200bp region is bound by the transcription factor MAX, and 0 otherwise.\n", + "\n", + " Metadata:\n", + " Each sequence is annotated with the celltype of origin (a string) and the chromosome of origin (a string).\n", + " \n", + " Website:\n", + " https://www.synapse.org/#!Synapse:syn6131484\n", + " \"\"\"\n", + "\n", + " def __init__(self, root_dir='data', download=False, split_scheme='official'):\n", + " itime = time.time()\n", + " self._dataset_name = 'encode-tfbs'\n", + " self._version = '1.0'\n", + " self._download_url = 'https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/'\n", + " self._data_dir = self.initialize_data_dir(root_dir, download)\n", + " self._y_size = 1\n", + " self._n_classes = 2\n", + " \n", + " # self._train_chroms = ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX']\n", + " self._train_chroms = ['chr2', 'chr9', 'chr11']\n", + " self._test_chroms = ['chr1', 'chr8', 'chr21']\n", + " self._transcription_factor = 'MAX'\n", + " self._train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", + " self._val_celltype = ['A549']\n", + " self._test_celltype = ['GM12878']\n", + " self._all_chroms = self._train_chroms + self._test_chroms\n", + " self._all_celltypes = self._train_celltypes + self._val_celltype + self._test_celltype\n", + " \n", + " self._metadata_map = {}\n", + " self._metadata_map['chr'] = self._all_chroms #['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']\n", + " self._metadata_map['celltype'] = self._all_celltypes\n", + " \n", + " # Get the splits\n", + " if split_scheme=='official':\n", + " split_scheme = 'standard'\n", + " \n", + " self._split_scheme = split_scheme\n", + " self._split_dict = {\n", + " 'train': 0,\n", + " 'id_val': 1,\n", + " 'test': 2,\n", + " 'val': 3\n", + " }\n", + " self._split_names = {\n", + " 'train': 'Train',\n", + " 'id_val': 'Validation (ID)',\n", + " 'test': 'Test',\n", + " 'val': 'Validation (OOD)',\n", + " }\n", + " \n", + " # Load sequence and DNase features\n", + " sequence_filename = os.path.join(self._data_dir, 'sequence.npz')\n", + " seq_arr = np.load(sequence_filename)\n", + " self._seq_bp = {}\n", + " for chrom in self._all_chroms: #seq_arr:\n", + " self._seq_bp[chrom] = seq_arr[chrom]\n", + " print(chrom, time.time() - itime)\n", + " \n", + " self._dnase_allcelltypes = {}\n", + " for ct in self._all_celltypes:\n", + " dnase_filename = os.path.join(self._data_dir, '{}_dnase.npz'.format(ct))\n", + " dnase_npz_contents = np.load(dnase_filename)\n", + " self._dnase_allcelltypes[ct] = {}\n", + " for chrom in self._all_chroms: #self._seq_bp:\n", + " self._dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom]\n", + " print(ct, time.time() - itime)\n", + " \n", + " # Read in metadata dataframe from training+validation data\n", + " train_regions_labeled = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.train.labels.tsv.gz'.format(self._transcription_factor)), sep='\\t')\n", + " val_regions_labeled = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.val.labels.tsv.gz'.format(self._transcription_factor)), sep='\\t')\n", + " training_df = train_regions_labeled[np.isin(train_regions_labeled['chr'], self._train_chroms)]\n", + " val_df = val_regions_labeled[np.isin(val_regions_labeled['chr'], self._test_chroms)]\n", + " all_df = pd.concat([training_df, val_df])\n", + " \n", + " # Filter by start/stop coordinate if needed (TODO: remove for final version)\n", + " filter_msk = all_df['start'] >= 0\n", + " filter_msk = all_df['start']%1000 == 0\n", + " all_df = all_df[filter_msk]\n", + " \n", + " pd_list = []\n", + " for ct in self._all_celltypes:\n", + " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", + " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", + " tc_chr.insert(len(tc_chr.columns), 'celltype', ct)\n", + " pd_list.append(tc_chr)\n", + " metadata_df = pd.concat(pd_list)\n", + " \n", + " # Get the y values, and remove ambiguous labels by default.\n", + " y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", + " non_ambig_mask = (y_array != -1)\n", + " metadata_df['y'] = y_array\n", + " self._metadata_df = metadata_df[non_ambig_mask]\n", + " \n", + " train_regions_mask = np.isin(self._metadata_df['chr'], self._train_chroms)\n", + " val_regions_mask = np.isin(self._metadata_df['chr'], self._test_chroms)\n", + " train_celltype_mask = np.isin(self._metadata_df['celltype'], self._train_celltypes)\n", + " val_celltype_mask = np.isin(self._metadata_df['celltype'], self._val_celltype)\n", + " test_celltype_mask = np.isin(self._metadata_df['celltype'], self._test_celltype)\n", + " \n", + " split_array = -1*np.ones(self._metadata_df.shape[0]).astype(int)\n", + " split_array[np.logical_and(train_regions_mask, train_celltype_mask)] = self._split_dict['train']\n", + " split_array[np.logical_and(val_regions_mask, test_celltype_mask)] = self._split_dict['test']\n", + " # Validate using test chr, either using a designated validation cell line ('val') or a training cell line ('id_val')\n", + " split_array[np.logical_and(val_regions_mask, val_celltype_mask)] = self._split_dict['val']\n", + " split_array[np.logical_and(val_regions_mask, train_celltype_mask)] = self._split_dict['id_val']\n", + " \n", + " if self._split_scheme=='standard':\n", + " self._metadata_df.insert(len(self._metadata_df.columns), 'split', split_array)\n", + " else:\n", + " raise ValueError(f'Split scheme {self._split_scheme} not recognized')\n", + " \n", + " self._metadata_df = self._metadata_df[self._metadata_df['split'] != -1]\n", + " self._split_array = self._metadata_df['split'].values\n", + " \n", + " chr_ints = self._metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['chr'])] )).values\n", + " celltype_ints = self._metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['celltype'])] )).values\n", + " self._y_array = torch.LongTensor(np.array(self._metadata_df['y']))\n", + " \n", + " self._metadata_array = torch.stack(\n", + " (torch.LongTensor(chr_ints), \n", + " torch.LongTensor(celltype_ints), \n", + " self._y_array),\n", + " dim=1)\n", + " self._metadata_fields = ['chr', 'celltype', 'y']\n", + " \n", + " self._eval_grouper = CombinatorialGrouper(\n", + " dataset=self,\n", + " groupby_fields=['celltype'])\n", + " \n", + " self._metric = Accuracy()\n", + " \n", + " super().__init__(root_dir, download, split_scheme)\n", + "\n", + " def get_input(self, idx):\n", + " \"\"\"\n", + " Returns x for a given idx.\n", + " Computes this from: \n", + " (1) sequence features in self._seq_bp\n", + " (2) DNase features in self._dnase_allcelltypes\n", + " (3) Metadata for the index (location along the genome with 200bp window width)\n", + " \"\"\"\n", + " this_metadata = self._metadata_df.iloc[idx, :]\n", + " flank_size = 400\n", + " interval_start = this_metadata['start'] - flank_size\n", + " interval_end = this_metadata['stop'] + flank_size\n", + " dnase_this = self._dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end]\n", + " seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end]\n", + " return torch.tensor(np.column_stack([seq_this, dnase_this]))\n", + "\n", + " def eval(self, y_pred, y_true, metadata):\n", + " return self.standard_group_eval(\n", + " self._metric,\n", + " self._eval_grouper,\n", + " y_pred, y_true, metadata)" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 139, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "chr2 3.836134910583496\n", + "chr9 6.034452438354492\n", + "chr11 8.16244888305664\n", + "chr1 12.12421178817749\n", + "chr8 14.44963550567627\n", + "chr21 15.212148189544678\n", + "H1-hESC 21.892271518707275\n", + "HCT116 28.37229895591736\n", + "HeLa-S3 35.18828296661377\n", + "HepG2 41.83891773223877\n", + "K562 48.590251445770264\n", + "A549 55.3311812877655\n", + "GM12878 61.93817687034607\n" + ] + } + ], + "source": [ + "full_dataset_encode = EncodeTFBSDataset(\n", + " root_dir=config.root_dir,\n", + " download=config.download,\n", + " split_scheme=config.split_scheme,\n", + " **config.dataset_kwargs)" + ] + }, + { + "cell_type": "code", + "execution_count": 140, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "<__main__.EncodeTFBSDataset at 0x7fe6b69d33a0>" ] }, - "execution_count": 10, + "execution_count": 140, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "full_dataset_camelyon17" + "full_dataset_encode" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 14, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "image_base None\n" - ] + "data": { + "text/plain": [ + "(array([0, 1]), array([227977, 227977]))" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" } ], - "source": [ - "supported.datasets[config_encode.dataset]\n", - "print(config_camelyon.train_transform, config_encode.train_transform)" - ] + "source": [] }, { "cell_type": "code", - "execution_count": 11, - "metadata": {}, + "execution_count": 17, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([1, 1, 1, ..., 0, 0, 0]) torch.Size([455954])\n" + ] + }, { "data": { "text/plain": [ - "1" + "['patches/patient_004_node_4/patch_patient_004_node_4_x_3328_y_21792.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3200_y_22272.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3168_y_22272.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3328_y_21760.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3232_y_22240.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3168_y_22240.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3136_y_22208.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2656_y_18880.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3136_y_22240.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3296_y_21856.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3296_y_21792.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3360_y_21824.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3360_y_21760.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3296_y_21824.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3328_y_21824.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2688_y_18912.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3168_y_22176.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2688_y_18816.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3200_y_22176.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3168_y_22208.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2688_y_18880.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3296_y_21760.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2656_y_18848.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3136_y_22272.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3264_y_21856.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3264_y_21824.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2688_y_18848.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3264_y_21792.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2656_y_18944.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3200_y_22208.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3232_y_22208.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3200_y_22240.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2688_y_18944.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3360_y_21792.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2656_y_18912.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2656_y_18816.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12960_y_36128.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_35968.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36512.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36064.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13312_y_36320.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35968.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34560.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36384.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36192.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13120_y_35936.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36480.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_35680.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35648.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36032.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36416.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36160.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12160_y_34752.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36288.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36160.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_35744.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_35840.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36320.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12896_y_35648.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35904.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36512.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13120_y_35904.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35872.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36192.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36224.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36416.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_35872.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13152_y_35872.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_35904.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36224.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36192.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_35968.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_35680.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_35648.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36096.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_35616.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36224.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35776.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36064.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35936.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_35808.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36224.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36320.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36032.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36128.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34752.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36288.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35904.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36160.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_36192.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13184_y_35968.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_35808.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13184_y_35936.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36192.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34560.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34528.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13120_y_36064.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36032.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34784.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36288.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36224.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36256.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35584.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_35616.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34720.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36352.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36160.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34624.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36416.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36384.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_36352.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_35712.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12960_y_36096.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12960_y_36352.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36384.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36352.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36448.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36448.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_36288.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36128.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12064_y_34560.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_35808.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34592.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12960_y_36448.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_35616.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36224.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36000.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_35840.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36096.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36416.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36352.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12064_y_34720.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_36064.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_35648.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36352.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_36096.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13120_y_35872.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36192.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36256.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_35840.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36128.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_35872.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_36064.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36032.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36320.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_36224.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_35968.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_35808.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36064.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12896_y_36224.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_35744.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36128.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36352.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36096.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13312_y_36224.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35872.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34624.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_35872.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36384.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_35936.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_35936.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36192.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34720.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_36096.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36352.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36128.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_35488.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36512.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13376_y_36192.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36096.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36160.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13376_y_36128.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36064.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36256.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_36128.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36480.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36480.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36480.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36544.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_36032.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36384.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36096.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36320.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36224.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12896_y_36256.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13344_y_36288.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34656.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36544.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36064.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_35936.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35776.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36064.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13120_y_36000.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35744.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12960_y_36416.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36224.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36192.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_35840.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_35872.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36160.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12160_y_34720.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36192.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_35872.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36096.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36480.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34592.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35936.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36096.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13376_y_36160.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36288.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_35872.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12064_y_34688.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36160.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12896_y_35616.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36096.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35616.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_35840.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36000.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36032.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_35840.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35808.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_35904.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36032.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34752.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36032.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13120_y_35968.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12064_y_34624.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_35680.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_35808.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12160_y_34688.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_35968.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36000.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36416.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36288.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36352.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36096.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12960_y_36064.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34688.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13152_y_35936.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13152_y_36000.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36224.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13344_y_36192.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36096.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35744.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12064_y_34592.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34688.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_35840.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36320.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_35584.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34656.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36416.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_35680.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36000.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_35936.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34784.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_35936.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35648.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_35712.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36416.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36256.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36416.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36192.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_36320.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12896_y_36288.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36064.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36128.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36128.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36096.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36128.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36384.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12160_y_34656.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36448.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36352.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_36000.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36320.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_35840.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12064_y_34528.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_35936.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36416.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36320.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_35872.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36256.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36544.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13120_y_36032.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36224.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_35904.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36064.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36160.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36256.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_35488.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36448.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36288.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_35936.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36096.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36256.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36128.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13280_y_36288.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_35968.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35680.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_36160.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_35648.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12960_y_36384.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36288.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36128.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36512.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36160.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13344_y_36224.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_36128.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36128.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36480.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36224.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_36256.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13184_y_35904.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36256.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_35776.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36160.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_35840.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_35904.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35968.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_35968.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13344_y_36160.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_35968.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_35872.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36032.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13312_y_36288.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_35776.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35808.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13184_y_35872.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_35904.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13152_y_35968.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_35904.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36256.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36288.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36448.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36000.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36448.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_36000.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_36032.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_35872.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36320.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36384.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13152_y_35904.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36256.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16416.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16544.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16384_y_24352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_16544.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16288_y_24448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17216.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_15968.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_17056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18400_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16032.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16544_y_24768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_16288.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17024_y_24512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_16224.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15488.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_17440.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17536_y_15552.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16384.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_15904.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16320.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17472.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16864.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16192.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17632_y_15104.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_17024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16704.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_15648.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16416.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16832.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17216.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_15616.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16288.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_15680.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16896.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16320_y_25056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17344_y_28160.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16704.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16928.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15520.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16480_y_24992.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17536.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_17216.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_15872.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17088_y_24448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_16192.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16160.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17088_y_24512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19360_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_17024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16160.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_16352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_17088.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16160.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16256.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16896.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16544.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16416.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16320.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_17152.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_17248.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_17024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16320_y_24448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_16800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16768_y_25152.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16736_y_25152.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16704_y_24640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16064.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17600_y_15488.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_17376.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16864_y_24288.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17120_y_25088.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16544_y_24800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18304_y_16064.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17056_y_24512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16096.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16384.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17280_y_24768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16288.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_17056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16768_y_24928.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16320.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_17024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17184.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16288_y_25024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_15648.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16320.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16352_y_25024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17184.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_17056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16544.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16288.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_15712.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_17152.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_16480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17216.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16544.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_15744.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16096.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16256_y_24448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16480_y_24832.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_25024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16832.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_16448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_15680.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16480_y_24800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16992_y_24448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_16768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16224.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_17280.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16832_y_24288.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_17408.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16192.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16960.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16000.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17088.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17248_y_24768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16832.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17504.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17632_y_15520.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16832.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16992.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17248.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17504.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16256_y_24480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16000.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16320_y_24608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16416.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17280_y_28064.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16032.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17536_y_25120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17152.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16192.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17408_y_24896.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19296_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16096.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16992.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16704.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16864.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_17056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_17120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16352_y_24512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_15648.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16896.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_16704.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_17088.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16256.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19360_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16768_y_25184.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16288.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17184.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16160_y_24864.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16032.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16416.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16640_y_25440.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17280_y_28192.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16800_y_25120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_17344.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16224.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16384.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16832.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16416.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18400_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17600_y_15392.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16320_y_24480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_15968.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16640_y_24576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16992_y_24960.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16672_y_24512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18336_y_16064.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16096.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16128_y_24896.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17440.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_17248.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16544_y_25024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16064.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_17440.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17536.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17440_y_24928.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_17216.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16864_y_24352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16224.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_17184.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16064.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16000.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_17344.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_15680.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16672_y_24928.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17536_y_15520.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16704_y_24608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16544.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_16416.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17312_y_28128.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17312.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_17120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16448_y_24768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16128.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16640_y_24640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16480_y_25024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17152.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_17120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_17280.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_15584.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_15648.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17408_y_24928.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16832.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17344_y_28128.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_17184.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_17440.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16864.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18400_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_16960.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_15968.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_15680.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_17024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_15616.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_16192.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18336_y_16160.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16864.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17024_y_24960.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16320_y_24640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15328.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17408.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_15712.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16736_y_24608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16704.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16928.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17088.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16704.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16960.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16960.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16096.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16448_y_25024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16384.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16832.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17376_y_24896.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_15648.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16064.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_15584.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16192.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16544.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16864.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16768_y_25120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_17248.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16384_y_24640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_17184.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_17088.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_17376.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_17088.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16544.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16064.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16512_y_24352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17568_y_15488.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_16128.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16992.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16704.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15456.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16480_y_25056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_17088.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16256.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19296_y_16832.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16928.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17344.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_17120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16352_y_24608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15104.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16320.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16800_y_24928.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16672_y_24960.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17696_y_15488.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_15744.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_17344.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16704_y_24960.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17184_y_25120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16064.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16992_y_24480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_15776.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_15744.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16096.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19328_y_16800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17856_y_15296.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_15712.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_15904.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16384_y_24512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16992.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16096.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16512_y_24800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_17216.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_17056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_25120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17280.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16416_y_24320.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_17376.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16576_y_24768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_17440.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17312_y_28032.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16992.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17568_y_25216.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16320.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_17120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16256.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_17024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17056_y_24480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16832_y_24864.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16704.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16288_y_24416.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16704.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16192.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16256.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16960.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17472.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16320.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_15776.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_17440.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16192.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_17216.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16544.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16064.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_17376.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17152_y_25088.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16032.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16544.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16224.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17312.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_17056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16384.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_16416.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16416.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16576_y_24608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17248.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16896.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17248_y_24800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16928.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16896.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16416.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16000.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16640_y_24384.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16416.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17632_y_15072.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_17408.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17600_y_15136.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_15936.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17600_y_15072.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16000.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16000.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16864_y_24320.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16800_y_25152.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16896_y_24576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_17440.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16096.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16384.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16384.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15200.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17120_y_25152.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17568_y_15392.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16416_y_24352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16512_y_24320.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16096.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16480_y_24352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16960.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16960.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_17248.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16576_y_25024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16032.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16704.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_15648.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16448_y_24384.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_25088.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17568_y_25152.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17216.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18336_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16928.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15424.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16992.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16288.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15072.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16992.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_17280.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18400_y_16352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17696_y_15232.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_17152.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16288.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15136.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_15712.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_15712.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16704.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_15680.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19296_y_16544.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17248.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16544_y_25056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_16832.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16288_y_24480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16384.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_17248.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16128.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_17440.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17696_y_15520.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16512_y_25024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16064.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_16128.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_17184.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16416.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17280_y_24800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_16256.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_16096.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19328_y_16832.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16384_y_24320.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17600_y_25216.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_17248.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17280.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16960.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16640_y_25408.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16736_y_24640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16288_y_25056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17152.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16224.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_17152.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_17024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16928.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16736_y_24576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16640_y_25056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16512_y_24768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_15680.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16896_y_24320.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16864.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16896.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_16448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17056_y_24448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16704.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16160.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16672_y_25088.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16864.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17600_y_15360.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16032.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16704.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_15936.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_17088.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17632_y_15200.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_15584.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19328_y_16768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16416.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16288_y_25088.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16480_y_24768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16832.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17408.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16704.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_16512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_17312.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16288_y_24512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_17408.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_17024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16640_y_24608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16128.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16672_y_24640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17152_y_25120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17024_y_24928.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17568_y_25120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16384.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_17184.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16960.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24384.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16160.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17632_y_15488.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_15648.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_17056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_15744.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16320_y_25024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16736_y_24672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17632_y_15424.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_16064.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24992.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_17120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19296_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16896.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18304_y_16032.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17632_y_15136.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_17056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16128.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16032.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_17408.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16544.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_17024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17568_y_15360.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16416.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_17280.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17088.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16544.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_15968.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16288.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_17440.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16672_y_24608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16448_y_24352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18336_y_16032.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16384.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19360_y_16768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_17184.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16832.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16992.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_17152.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16160.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17504.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_15936.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24960.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16672.png',\n", + " ...]" ] }, - "execution_count": 11, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "full_dataset.y_size" + "print(full_dataset._y_array, full_dataset._y_array.shape)\n", + "print(np.unique(full_dataset.y_array.numpy(), return_counts=True))\n", + "print(np.unique(full_dataset._metadata_df['split'], return_counts=True))\n", + "\n", + "#full_dataset._input_array" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# pd.read_csv(os.path.join('data/camelyon17_v1.0/metadata.csv'), index_col=0, dtype={'patient': 'str'})" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "import copy\n", + "full_dataset_camelyon17 = copy.deepcopy(full_dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "image_base None\n" + ] + } + ], + "source": [ + "supported.datasets[config_encode.dataset]\n", + "print(config_camelyon.train_transform, config_encode.train_transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, @@ -525,84 +1893,7 @@ "algorithm = initialize_algorithm(\n", " config=config,\n", " datasets=datasets,\n", - " train_grouper=train_grouper)\n", - "\n", - "if not config.eval_only:\n", - " ## Load saved results if resuming\n", - " resume_success = False\n", - " if resume:\n", - " save_path = os.path.join(config.log_dir, 'last_model.pth')\n", - " if not os.path.exists(save_path):\n", - " epochs = [\n", - " int(file.split('_')[0])\n", - " for file in os.listdir(config.log_dir) if file.endswith('.pth')]\n", - " if len(epochs) > 0:\n", - " latest_epoch = max(epochs)\n", - " save_path = os.path.join(config.log_dir, f'{latest_epoch}_model.pth')\n", - " try:\n", - " prev_epoch, best_val_metric = load(algorithm, save_path)\n", - " epoch_offset = prev_epoch + 1\n", - " logger.write(f'Resuming from epoch {epoch_offset} with best val metric {best_val_metric}')\n", - " resume_success = True\n", - " except FileNotFoundError:\n", - " pass\n", - "\n", - " if resume_success == False:\n", - " epoch_offset=0\n", - " best_val_metric=None\n", - "\n", - "\n", - " train(\n", - " algorithm=algorithm,\n", - " datasets=datasets,\n", - " general_logger=logger,\n", - " config=config,\n", - " epoch_offset=epoch_offset,\n", - " best_val_metric=best_val_metric)\n", - "else:\n", - " if config.eval_epoch is None:\n", - " eval_model_path = os.path.join(config.log_dir, 'best_model.pth')\n", - " else:\n", - " eval_model_path = os.path.join(config.log_dir, f'{config.eval_epoch}_model.pth')\n", - " best_epoch, best_val_metric = load(algorithm, eval_model_path)\n", - " if config.eval_epoch is None:\n", - " epoch = best_epoch\n", - " else:\n", - " epoch = config.eval_epoch\n", - " evaluate(\n", - " algorithm=algorithm,\n", - " datasets=datasets,\n", - " epoch=epoch,\n", - " general_logger=logger,\n", - " config=config)\n", - "\n", - "logger.close()\n", - "for split in datasets:\n", - " datasets[split]['eval_logger'].close()\n", - " datasets[split]['algo_logger'].close()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Initialize dataset object" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np, pandas as pd, os, time, torch, torchvision\n", - "data_dir = '/oak/stanford/groups/akundaje/abalsubr/DREAM/wilds/codalab_archive/'\n", - "tf = 'MAX'\n", - "itime = time.time()\n", - "train_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.train.labels.tsv.gz'.format(tf)), sep='\\t')\n", - "print(time.time() - itime)\n", - "val_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.val.labels.tsv.gz'.format(tf)), sep='\\t')\n", - "print(time.time() - itime)" + " train_grouper=train_grouper)" ] }, { @@ -676,6 +1967,95 @@ " print(ct, time.time() - itime)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Train/eval" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if not config.eval_only:\n", + " ## Load saved results if resuming\n", + " resume_success = False\n", + " if resume:\n", + " save_path = os.path.join(config.log_dir, 'last_model.pth')\n", + " if not os.path.exists(save_path):\n", + " epochs = [\n", + " int(file.split('_')[0])\n", + " for file in os.listdir(config.log_dir) if file.endswith('.pth')]\n", + " if len(epochs) > 0:\n", + " latest_epoch = max(epochs)\n", + " save_path = os.path.join(config.log_dir, f'{latest_epoch}_model.pth')\n", + " try:\n", + " prev_epoch, best_val_metric = load(algorithm, save_path)\n", + " epoch_offset = prev_epoch + 1\n", + " logger.write(f'Resuming from epoch {epoch_offset} with best val metric {best_val_metric}')\n", + " resume_success = True\n", + " except FileNotFoundError:\n", + " pass\n", + "\n", + " if resume_success == False:\n", + " epoch_offset=0\n", + " best_val_metric=None\n", + "\n", + "\n", + " train(\n", + " algorithm=algorithm,\n", + " datasets=datasets,\n", + " general_logger=logger,\n", + " config=config,\n", + " epoch_offset=epoch_offset,\n", + " best_val_metric=best_val_metric)\n", + "else:\n", + " if config.eval_epoch is None:\n", + " eval_model_path = os.path.join(config.log_dir, 'best_model.pth')\n", + " else:\n", + " eval_model_path = os.path.join(config.log_dir, f'{config.eval_epoch}_model.pth')\n", + " best_epoch, best_val_metric = load(algorithm, eval_model_path)\n", + " if config.eval_epoch is None:\n", + " epoch = best_epoch\n", + " else:\n", + " epoch = config.eval_epoch\n", + " evaluate(\n", + " algorithm=algorithm,\n", + " datasets=datasets,\n", + " epoch=epoch,\n", + " general_logger=logger,\n", + " config=config)\n", + "\n", + "logger.close()\n", + "for split in datasets:\n", + " datasets[split]['eval_logger'].close()\n", + " datasets[split]['algo_logger'].close()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": 78, @@ -787,301 +2167,6 @@ "count_parameters(model)\n", "lst" ] - }, - { - "cell_type": "code", - "execution_count": 48, - "metadata": {}, - "outputs": [ - { - "ename": "AttributeError", - "evalue": "'module' object has no attribute 'isin'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mtr_chrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr2'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr9'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr11'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mte_chrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr1'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr8'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr21'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mtraining_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtr_chrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mval_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mval_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mte_chrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mall_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtraining_df\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_df\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mAttributeError\u001b[0m: 'module' object has no attribute 'isin'" - ] - } - ], - "source": [ - "tr_chrs = ['chr2', 'chr9', 'chr11']\n", - "te_chrs = ['chr1', 'chr8', 'chr21']\n", - "training_df = train_chr[np.isin(train_chr['chr'], tr_chrs)]\n", - "val_df = val_chr[np.isin(val_chr['chr'], te_chrs)]\n", - "all_df = pd.concat([training_df, val_df])\n", - "\n", - "#filter_msk = all_df['start'] >= 0\n", - "filter_msk = all_df['start']%1000 == 0\n", - "all_df = all_df[filter_msk]" - ] - }, - { - "cell_type": "code", - "execution_count": 49, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1.12.1\n" - ] - } - ], - "source": [ - "print(np.__version__)" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/users/abalsubr/anaconda2/envs/scs3/lib/python3.6/site-packages/ipykernel_launcher.py:6: SettingWithCopyWarning: \n", - "A value is trying to be set on a copy of a slice from a DataFrame.\n", - "Try using .loc[row_indexer,col_indexer] = value instead\n", - "\n", - "See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy\n", - " \n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1.659163236618042\n" - ] - } - ], - "source": [ - "itime = time.time()\n", - "pd_list = []\n", - "for ct in all_celltypes:\n", - " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", - " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", - " tc_chr['celltype'] = ct\n", - " pd_list.append(tc_chr)\n", - "metadata_df = pd.concat(pd_list)\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "3.0391879081726074\n" - ] - } - ], - "source": [ - "itime = time.time()\n", - "y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", - "non_ambig_mask = (y_array != -1)\n", - "metadata_df['y'] = y_array\n", - "_metadata_df = metadata_df[non_ambig_mask]\n", - "_y_array = torch.LongTensor(y_array[non_ambig_mask])\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "12.390011310577393\n" - ] - } - ], - "source": [ - "itime = time.time()\n", - "chr_ints = _metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['chr'])] )).values\n", - "celltype_ints = _metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['celltype'])] )).values\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 53, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/users/abalsubr/anaconda2/envs/scs3/lib/python3.6/site-packages/ipykernel_launcher.py:12: SettingWithCopyWarning: \n", - "A value is trying to be set on a copy of a slice from a DataFrame.\n", - "Try using .loc[row_indexer,col_indexer] = value instead\n", - "\n", - "See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy\n", - " if sys.path[0] == '':\n" - ] - } - ], - "source": [ - "train_chr_mask = np.isin(_metadata_df['chr'], tr_chrs)\n", - "val_chr_mask = np.isin(_metadata_df['chr'], te_chrs)\n", - "train_celltype_mask = np.isin(_metadata_df['celltype'], train_celltypes)\n", - "val_celltype_mask = np.isin(_metadata_df['celltype'], val_celltype)\n", - "test_celltype_mask = np.isin(_metadata_df['celltype'], test_celltype)\n", - "\n", - "split_array = -1*np.ones(_metadata_df.shape[0]).astype(int)\n", - "split_array[np.logical_and(train_chr_mask, train_celltype_mask)] = _split_dict['train']\n", - "split_array[np.logical_and(val_chr_mask, test_celltype_mask)] = _split_dict['test']\n", - "split_array[np.logical_and(val_chr_mask, val_celltype_mask)] = _split_dict['val-ood']\n", - "split_array[np.logical_and(val_chr_mask, train_celltype_mask)] = _split_dict['val-id']\n", - "_metadata_df['split'] = split_array\n", - "_split_array = split_array" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# get_input (idx)" - ] - }, - { - "cell_type": "code", - "execution_count": 153, - "metadata": {}, - "outputs": [], - "source": [ - "idx = 3\n", - "this_metadata = _metadata_df.iloc[idx, :]\n", - "\n", - "itime = time.time()\n", - "flank_size = 400\n", - "interval_start = this_metadata['start'] - flank_size\n", - "interval_end = this_metadata['stop'] + flank_size\n", - "dnase_this = _dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end]\n", - "seq_this = _seq_bp[this_metadata['chr']][interval_start:interval_end]\n", - "data = np.column_stack([seq_this, dnase_this])\n", - "# print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 154, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "4600" - ] - }, - "execution_count": 154, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "data.shape\n", - "interval_end\n", - "# itime = time.time()\n", - "# np.save(os.path.join(data_dir, 'stmp.npy'), sa)\n", - "# print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 78, - "metadata": {}, - "outputs": [ - { - "ename": "TypeError", - "evalue": "can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mitime\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m metadata_array = torch.stack(\n\u001b[0;32m----> 3\u001b[0;31m (torch.LongTensor(metadata_df['chr'].values), \n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLongTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetadata_df\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'celltype'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m self._y_array),\n", - "\u001b[0;31mTypeError\u001b[0m: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool." - ] - } - ], - "source": [ - "itime = time.time()\n", - "metadata_array = torch.stack(\n", - " (torch.LongTensor(chr_ints), \n", - " torch.LongTensor(celltype_ints), \n", - " _y_array),\n", - " dim=1)\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 156, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name '_metadata_array' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0m_metadata_array\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mNameError\u001b[0m: name '_metadata_array' is not defined" - ] - } - ], - "source": [ - "_metadata_array" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "from examples.models.model_attributes import model_attributes" - ] } ], "metadata": { diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 327ffb3e..f8b66f25 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -33,25 +33,43 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): self._y_size = 1 self._n_classes = 2 - # self._tr_chrs = ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] - self._tr_chrs = ['chr2', 'chr9', 'chr11'] - self._te_chrs = ['chr1', 'chr8', 'chr21'] + # self._train_chroms = ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] + self._train_chroms = ['chr2', 'chr9', 'chr11'] + self._test_chroms = ['chr1', 'chr8', 'chr21'] self._transcription_factor = 'MAX' self._train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] self._val_celltype = ['A549'] self._test_celltype = ['GM12878'] + self._all_chroms = self._train_chroms + self._test_chroms self._all_celltypes = self._train_celltypes + self._val_celltype + self._test_celltype - self._metadata_fields = ['chr', 'celltype', 'y'] self._metadata_map = {} - self._metadata_map['chr'] = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] + self._metadata_map['chr'] = self._all_chroms #['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] self._metadata_map['celltype'] = self._all_celltypes + # Get the splits + if split_scheme=='official': + split_scheme = 'standard' + + self._split_scheme = split_scheme + self._split_dict = { + 'train': 0, + 'id_val': 1, + 'test': 2, + 'val': 3 + } + self._split_names = { + 'train': 'Train', + 'id_val': 'Validation (ID)', + 'test': 'Test', + 'val': 'Validation (OOD)', + } + # Load sequence and DNase features sequence_filename = os.path.join(self._data_dir, 'sequence.npz') seq_arr = np.load(sequence_filename) self._seq_bp = {} - for chrom in seq_arr: + for chrom in self._all_chroms: #seq_arr: self._seq_bp[chrom] = seq_arr[chrom] print(chrom, time.time() - itime) @@ -60,29 +78,27 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): dnase_filename = os.path.join(self._data_dir, '{}_dnase.npz'.format(ct)) dnase_npz_contents = np.load(dnase_filename) self._dnase_allcelltypes[ct] = {} - for chrom in self._seq_bp: + for chrom in self._all_chroms: #self._seq_bp: self._dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom] print(ct, time.time() - itime) # Read in metadata dataframe from training+validation data train_regions_labeled = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.train.labels.tsv.gz'.format(self._transcription_factor)), sep='\t') val_regions_labeled = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.val.labels.tsv.gz'.format(self._transcription_factor)), sep='\t') - training_df = train_regions_labeled[np.isin(train_regions_labeled['chr'], self._tr_chrs)] - val_df = val_regions_labeled[np.isin(val_regions_labeled['chr'], self._te_chrs)] + training_df = train_regions_labeled[np.isin(train_regions_labeled['chr'], self._train_chroms)] + val_df = val_regions_labeled[np.isin(val_regions_labeled['chr'], self._test_chroms)] all_df = pd.concat([training_df, val_df]) - print(train_regions_labeled, time.time() - itime) - # Filter by start/stop coordinate if needed - # (TODO: remove for final version) + # Filter by start/stop coordinate if needed (TODO: remove for final version) filter_msk = all_df['start'] >= 0 filter_msk = all_df['start']%1000 == 0 all_df = all_df[filter_msk] pd_list = [] - for ct in self._train_celltypes: + for ct in self._all_celltypes: tc_chr = all_df[['chr', 'start', 'stop', ct]] tc_chr.columns = ['chr', 'start', 'stop', 'y'] - tc_chr['celltype'] = ct + tc_chr.insert(len(tc_chr.columns), 'celltype', ct) pd_list.append(tc_chr) metadata_df = pd.concat(pd_list) @@ -91,33 +107,9 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): non_ambig_mask = (y_array != -1) metadata_df['y'] = y_array self._metadata_df = metadata_df[non_ambig_mask] - self._y_array = torch.LongTensor(y_array[non_ambig_mask]) - chr_ints = self._metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['chr'])] )).values - celltype_ints = self._metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['celltype'])] )).values - self._metadata_array = torch.stack( - (torch.LongTensor(chr_ints), - torch.LongTensor(celltype_ints), - self._y_array), - dim=1) - - # Get the splits - # TODO Extract splits as encoded in split_scheme. Hardcoded here for now. - self._split_scheme = split_scheme - self._split_dict = { - 'train': 0, - 'val-id': 1, - 'test': 2, - 'val-ood': 3 - } - self._split_names = { - 'train': 'Train', - 'val-id': 'Validation (ID)', - 'test': 'Test', - 'val-ood': 'Validation (OOD)', - } - train_regions_mask = np.isin(self._metadata_df['chr'], self._tr_chrs) - val_regions_mask = np.isin(self._metadata_df['chr'], self._te_chrs) + train_regions_mask = np.isin(self._metadata_df['chr'], self._train_chroms) + val_regions_mask = np.isin(self._metadata_df['chr'], self._test_chroms) train_celltype_mask = np.isin(self._metadata_df['celltype'], self._train_celltypes) val_celltype_mask = np.isin(self._metadata_df['celltype'], self._val_celltype) test_celltype_mask = np.isin(self._metadata_df['celltype'], self._test_celltype) @@ -125,17 +117,33 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): split_array = -1*np.ones(self._metadata_df.shape[0]).astype(int) split_array[np.logical_and(train_regions_mask, train_celltype_mask)] = self._split_dict['train'] split_array[np.logical_and(val_regions_mask, test_celltype_mask)] = self._split_dict['test'] - # Validate using test chr, either using a designated validation cell line ('val-ood') or a training cell line ('val-id') - split_array[np.logical_and(val_regions_mask, val_celltype_mask)] = self._split_dict['val-ood'] - split_array[np.logical_and(val_regions_mask, train_celltype_mask)] = self._split_dict['val-id'] + # Validate using test chr, either using a designated validation cell line ('val') or a training cell line ('id_val') + split_array[np.logical_and(val_regions_mask, val_celltype_mask)] = self._split_dict['val'] + split_array[np.logical_and(val_regions_mask, train_celltype_mask)] = self._split_dict['id_val'] + if self._split_scheme=='standard': - self._metadata_df['split'] = split_array - self._split_array = split_array + self._metadata_df.insert(len(self._metadata_df.columns), 'split', split_array) else: raise ValueError(f'Split scheme {self._split_scheme} not recognized') + + self._metadata_df = self._metadata_df[self._metadata_df['split'] != -1] + self._split_array = self._metadata_df['split'].values + + chr_ints = self._metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['chr'])] )).values + celltype_ints = self._metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['celltype'])] )).values + self._y_array = torch.LongTensor(np.array(self._metadata_df['y'])) + + self._metadata_array = torch.stack( + (torch.LongTensor(chr_ints), + torch.LongTensor(celltype_ints), + self._y_array), + dim=1) + self._metadata_fields = ['chr', 'celltype', 'y'] + self._eval_grouper = CombinatorialGrouper( dataset=self, groupby_fields=['celltype']) + self._metric = Accuracy() super().__init__(root_dir, download, split_scheme) @@ -152,8 +160,8 @@ def get_input(self, idx): flank_size = 400 interval_start = this_metadata['start'] - flank_size interval_end = this_metadata['stop'] + flank_size - dnase_this = _dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end] - seq_this = _seq_bp[this_metadata['chr']][interval_start:interval_end] + dnase_this = self._dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end] + seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end] return torch.tensor(np.column_stack([seq_this, dnase_this])) def eval(self, y_pred, y_true, metadata): From ba4be3dba7a209ca300ee8ea70ac8fc98549dd35 Mon Sep 17 00:00:00 2001 From: aikanor Date: Thu, 11 Feb 2021 09:15:50 -0800 Subject: [PATCH 023/244] integration 16/ (with most remaining bugfixes) --- examples/configs/datasets.py | 4 +- examples/models/CNN_genome.py | 1 + examples/models/initializer.py | 1 + examples/sbox_run_expt.ipynb | 1465 +++++--------------------- wilds/datasets/encodetfbs_dataset.py | 14 + 5 files changed, 270 insertions(+), 1215 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 58baf1c1..4bd331a0 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -96,13 +96,13 @@ 'train_transform': None, 'eval_transform': None, 'loss_function': 'cross_entropy', - 'groupby_fields': ['celltype'], + 'groupby_fields': ['celltype', 'y'], 'val_metric': 'acc_avg', 'val_metric_decreasing': False, 'optimizer': 'Adam', # 'optimizer_kwargs': { }, 'scheduler': None, - 'batch_size': 128, + 'batch_size': 64, 'lr': 0.001, 'weight_decay': 0.01, 'n_epochs': 1, diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py index cc464ef0..1c65b567 100644 --- a/examples/models/CNN_genome.py +++ b/examples/models/CNN_genome.py @@ -7,6 +7,7 @@ class Beagle(nn.Module): """ Neural net models over genomic sequence. Adapted from https://github.com/kundajelab/ChromDragoNN + Input: - s (Tensor): float torch tensor of shape (N, 5, 1000, 1) with batch size N. diff --git a/examples/models/initializer.py b/examples/models/initializer.py index cea5ebfc..fb77a5ea 100644 --- a/examples/models/initializer.py +++ b/examples/models/initializer.py @@ -6,6 +6,7 @@ from models.gnn import GINVirtual def initialize_model(config, d_out): + print('Dout: {}'.format(d_out)) if config.model == 'resnet18_ms': # multispectral resnet 18 model = ResNet18(num_classes=d_out, **config.model_kwargs) diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb index 56a9f6a2..e50f790b 100644 --- a/examples/sbox_run_expt.ipynb +++ b/examples/sbox_run_expt.ipynb @@ -11,7 +11,28 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'psutil'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpsutil\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpsutil\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mProcess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgetpid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmemory_info\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrss\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0;36m1024\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'psutil'" + ] + } + ], + "source": [ + "import os, psutil; print(psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -38,16 +59,16 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" + "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" ] }, - "execution_count": 2, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -142,7 +163,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -154,19 +175,13 @@ "config_encode = parser.parse_args(argstr_encode.split())\n", "config_encode = populate_defaults(config_encode)\n", "\n", - "config = config_camelyon\n", - "#config = config_encode" + "config = config_camelyon" ] }, { "cell_type": "code", - "execution_count": 6, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true - } - }, + "execution_count": 5, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -275,6 +290,19 @@ " dataset=full_dataset)" ] }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "import copy\n", + "full_dataset_camelyon17 = copy.deepcopy(full_dataset)\n", + "\n", + "# supported.datasets[config_encode.dataset]\n", + "# print(config_camelyon.train_transform, config_encode.train_transform)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -284,11 +312,10 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 30, "metadata": { - "collapsed": true, "jupyter": { - "outputs_hidden": true + "source_hidden": true } }, "outputs": [ @@ -296,46 +323,19 @@ "name": "stdout", "output_type": "stream", "text": [ - "chr2 3.817250967025757\n", - "chr9 6.033524990081787\n", - "chr11 8.150986433029175\n", - "chr1 12.036555290222168\n", - "chr8 14.306443929672241\n", - "chr21 15.043241739273071\n", - "H1-hESC 21.61008930206299\n", - "HCT116 28.000329971313477\n", - "HeLa-S3 34.6184778213501\n", - "HepG2 41.089255809783936\n", - "K562 47.70136523246765\n", - "A549 54.22390341758728\n", - "GM12878 60.65142226219177\n", - " chr start stop A549 GM12878 H1-hESC HCT116 HeLa-S3 \\\n", - "0 chr10 600 800 U U U U U \n", - "1 chr10 650 850 U U U U U \n", - "2 chr10 700 900 U U U U U \n", - "3 chr10 750 950 U U U U U \n", - "4 chr10 800 1000 U U U U U \n", - "... ... ... ... ... ... ... ... ... \n", - "51676731 chrX 155269750 155269950 U U U U U \n", - "51676732 chrX 155269800 155270000 U U U U U \n", - "51676733 chrX 155269850 155270050 U U U U U \n", - "51676734 chrX 155269900 155270100 U U U U U \n", - "51676735 chrX 155269950 155270150 U U U U U \n", - "\n", - " HepG2 K562 \n", - "0 U U \n", - "1 U U \n", - "2 U U \n", - "3 U U \n", - "4 U U \n", - "... ... ... \n", - "51676731 U U \n", - "51676732 U U \n", - "51676733 U U \n", - "51676734 U U \n", - "51676735 U U \n", - "\n", - "[51676736 rows x 10 columns] 130.07371044158936\n" + "chr2 3.657395362854004\n", + "chr9 5.770605564117432\n", + "chr11 7.801896095275879\n", + "chr1 11.56663990020752\n", + "chr8 13.764073133468628\n", + "chr21 14.483267068862915\n", + "H1-hESC 20.850953817367554\n", + "HCT116 27.05355429649353\n", + "HeLa-S3 33.51919412612915\n", + "HepG2 39.89570116996765\n", + "K562 46.36982774734497\n", + "A549 52.82617139816284\n", + "GM12878 59.167165994644165\n" ] } ], @@ -417,9 +417,9 @@ "all_df = pd.concat([training_df, val_df])\n", "\n", "# Filter by start/stop coordinate if needed (TODO: remove for final version)\n", - "filter_msk = all_df['start'] >= 0\n", - "filter_msk = all_df['start']%1000 == 0\n", - "all_df = all_df[filter_msk]\n", + "# filter_msk = all_df['start'] >= 0\n", + "# filter_msk = all_df['start']%1000 == 0\n", + "# all_df = all_df[filter_msk]\n", "\n", "pd_list = []\n", "for ct in _all_celltypes:\n", @@ -427,20 +427,39 @@ " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", " tc_chr.insert(len(tc_chr.columns), 'celltype', ct)\n", " pd_list.append(tc_chr)\n", - "metadata_df = pd.concat(pd_list)" + "metadata_df = pd.concat(pd_list)\n", + "\n", + "# Get the y values, and remove ambiguous labels by default.\n", + "y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", + "non_ambig_mask = (y_array != -1)\n", + "metadata_df['y'] = y_array\n", + "_metadata_df = metadata_df[non_ambig_mask]" ] }, { "cell_type": "code", - "execution_count": 131, - "metadata": {}, + "execution_count": 35, + "metadata": { + "jupyter": { + "source_hidden": true + } + }, "outputs": [], "source": [ - "# Get the y values, and remove ambiguous labels by default.\n", - "y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", - "non_ambig_mask = (y_array != -1)\n", - "metadata_df['y'] = y_array\n", - "_metadata_df = metadata_df[non_ambig_mask]\n", + "samp_ndces = []\n", + "itime = time.time()\n", + "for ct in _all_celltypes:\n", + " neg_msk = np.logical_and((_metadata_df['celltype'] == ct), (_metadata_df['y'] == 0))\n", + " pos_msk = np.logical_and((_metadata_df['celltype'] == ct), (_metadata_df['y'] == 1))\n", + " neg_ndces = np.where(neg_msk)[0]\n", + " pos_ndces = np.where(pos_msk)[0]\n", + " np.random.seed(42)\n", + " samp_neg_ndces = np.random.choice(neg_ndces, size=len(pos_ndces), replace=False)\n", + " samp_ndces.extend(samp_neg_ndces)\n", + " samp_ndces.extend(pos_ndces)\n", + " print(ct, time.time() - itime)\n", + "\n", + "_metadata_df = _metadata_df.iloc[samp_ndces, :]\n", "\n", "train_regions_mask = np.isin(_metadata_df['chr'], _train_chroms)\n", "val_regions_mask = np.isin(_metadata_df['chr'], _test_chroms)\n", @@ -472,7 +491,7 @@ " torch.LongTensor(celltype_ints), \n", " _y_array),\n", " dim=1)\n", - "_metadata_fields = ['chr', 'celltype', 'y']\n" + "_metadata_fields = ['chr', 'celltype', 'y']" ] }, { @@ -484,7 +503,7 @@ }, { "cell_type": "code", - "execution_count": 138, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ @@ -523,8 +542,8 @@ " self._y_size = 1\n", " self._n_classes = 2\n", " \n", - " # self._train_chroms = ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX']\n", - " self._train_chroms = ['chr2', 'chr9', 'chr11']\n", + " self._train_chroms = ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX']\n", + " # self._train_chroms = ['chr2', 'chr9', 'chr11']\n", " self._test_chroms = ['chr1', 'chr8', 'chr21']\n", " self._transcription_factor = 'MAX'\n", " self._train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", @@ -580,9 +599,9 @@ " all_df = pd.concat([training_df, val_df])\n", " \n", " # Filter by start/stop coordinate if needed (TODO: remove for final version)\n", - " filter_msk = all_df['start'] >= 0\n", - " filter_msk = all_df['start']%1000 == 0\n", - " all_df = all_df[filter_msk]\n", + " # filter_msk = all_df['start'] >= 0\n", + " # filter_msk = all_df['start']%1000 == 0\n", + " # all_df = all_df[filter_msk]\n", " \n", " pd_list = []\n", " for ct in self._all_celltypes:\n", @@ -598,6 +617,20 @@ " metadata_df['y'] = y_array\n", " self._metadata_df = metadata_df[non_ambig_mask]\n", " \n", + " samp_ndces = []\n", + " itime = time.time()\n", + " for ct in self._all_celltypes:\n", + " neg_msk = np.logical_and((self._metadata_df['celltype'] == ct), (self._metadata_df['y'] == 0))\n", + " pos_msk = np.logical_and((self._metadata_df['celltype'] == ct), (self._metadata_df['y'] == 1))\n", + " neg_ndces = np.where(neg_msk)[0]\n", + " pos_ndces = np.where(pos_msk)[0]\n", + " np.random.seed(42)\n", + " samp_neg_ndces = np.random.choice(neg_ndces, size=len(pos_ndces), replace=False)\n", + " samp_ndces.extend(samp_neg_ndces)\n", + " samp_ndces.extend(pos_ndces)\n", + " print(ct, time.time() - itime)\n", + " self._metadata_df = self._metadata_df.iloc[samp_ndces, :]\n", + " \n", " train_regions_mask = np.isin(self._metadata_df['chr'], self._train_chroms)\n", " val_regions_mask = np.isin(self._metadata_df['chr'], self._test_chroms)\n", " train_celltype_mask = np.isin(self._metadata_df['celltype'], self._train_celltypes)\n", @@ -663,26 +696,43 @@ }, { "cell_type": "code", - "execution_count": 139, + "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "chr2 3.836134910583496\n", - "chr9 6.034452438354492\n", - "chr11 8.16244888305664\n", - "chr1 12.12421178817749\n", - "chr8 14.44963550567627\n", - "chr21 15.212148189544678\n", - "H1-hESC 21.892271518707275\n", - "HCT116 28.37229895591736\n", - "HeLa-S3 35.18828296661377\n", - "HepG2 41.83891773223877\n", - "K562 48.590251445770264\n", - "A549 55.3311812877655\n", - "GM12878 61.93817687034607\n" + "chr2 3.718320846557617\n", + "chr3 6.73882269859314\n", + "chr4 9.651247501373291\n", + "chr5 12.439628839492798\n", + "chr6 15.05026388168335\n", + "chr7 17.475954055786133\n", + "chr9 19.6206693649292\n", + "chr10 21.68758535385132\n", + "chr11 23.74817419052124\n", + "chr12 25.81403160095215\n", + "chr13 27.559557676315308\n", + "chr14 29.18643832206726\n", + "chr15 30.739391565322876\n", + "chr16 32.11144256591797\n", + "chr17 33.348127126693726\n", + "chr18 34.53834342956543\n", + "chr19 35.434733629226685\n", + "chr20 36.399296283721924\n", + "chr22 37.1924102306366\n", + "chrX 39.56284308433533\n", + "chr1 43.3526566028595\n", + "chr8 45.583492040634155\n", + "chr21 46.311339378356934\n", + "H1-hESC 66.45292735099792\n", + "HCT116 86.06067085266113\n", + "HeLa-S3 106.47142815589905\n", + "HepG2 126.59437656402588\n", + "K562 146.93650436401367\n", + "A549 167.19306707382202\n", + "GM12878 187.4349775314331\n" ] } ], @@ -696,1068 +746,63 @@ }, { "cell_type": "code", - "execution_count": 140, + "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "<__main__.EncodeTFBSDataset at 0x7fe6b69d33a0>" + "(array(['A549', 'GM12878', 'H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'],\n", + " dtype=object),\n", + " array([ 5118, 1702, 8460, 12806, 8348, 11774, 12518]))" ] }, - "execution_count": 140, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "full_dataset_encode" + "np.unique(full_dataset_encode._metadata_df['celltype'], return_counts=True)" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 17, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "(array([0, 1]), array([227977, 227977]))" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([0, 0, 0, ..., 1, 1, 1]) torch.Size([60726])\n", + "(array([0, 1]), array([30426, 30300]))\n", + "(array([0, 1, 2, 3]), array([28556, 25350, 1702, 5118]))\n" + ] } ], - "source": [] + "source": [ + "full_dataset = copy.deepcopy(full_dataset_encode)\n", + "print(full_dataset._y_array, full_dataset._y_array.shape)\n", + "print(np.unique(full_dataset.y_array.numpy(), return_counts=True))\n", + "print(np.unique(full_dataset._metadata_df['split'], return_counts=True))\n", + "\n", + "#full_dataset._input_array" + ] }, { "cell_type": "code", - "execution_count": 17, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true - } - }, + "execution_count": 9, + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "tensor([1, 1, 1, ..., 0, 0, 0]) torch.Size([455954])\n" + "tensor([0, 0, 0, ..., 0, 0, 0]) torch.Size([5568233])\n", + "(array([0, 1]), array([5537933, 30300]))\n", + "(array([0, 1, 2, 3]), array([2533595, 2163528, 437124, 433986]))\n" ] - }, - { - "data": { - "text/plain": [ - "['patches/patient_004_node_4/patch_patient_004_node_4_x_3328_y_21792.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3200_y_22272.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3168_y_22272.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3328_y_21760.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3232_y_22240.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3168_y_22240.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3136_y_22208.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2656_y_18880.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3136_y_22240.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3296_y_21856.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3296_y_21792.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3360_y_21824.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3360_y_21760.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3296_y_21824.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3328_y_21824.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2688_y_18912.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3168_y_22176.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2688_y_18816.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3200_y_22176.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3168_y_22208.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2688_y_18880.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3296_y_21760.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2656_y_18848.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3136_y_22272.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3264_y_21856.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3264_y_21824.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2688_y_18848.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3264_y_21792.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2656_y_18944.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3200_y_22208.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3232_y_22208.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3200_y_22240.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2688_y_18944.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3360_y_21792.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2656_y_18912.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2656_y_18816.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12960_y_36128.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_35968.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36512.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36064.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13312_y_36320.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35968.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34560.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36384.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36192.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13120_y_35936.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36480.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_35680.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35648.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36032.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36416.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36160.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12160_y_34752.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36288.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36160.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_35744.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_35840.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36320.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12896_y_35648.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35904.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36512.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13120_y_35904.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35872.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36192.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36224.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36416.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_35872.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13152_y_35872.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_35904.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36224.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36192.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_35968.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_35680.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_35648.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36096.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_35616.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36224.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35776.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36064.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35936.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_35808.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36224.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36320.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36032.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36128.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34752.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36288.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35904.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36160.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_36192.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13184_y_35968.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_35808.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13184_y_35936.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36192.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34560.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34528.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13120_y_36064.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36032.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34784.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36288.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36224.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36256.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35584.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_35616.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34720.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36352.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36160.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34624.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36416.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36384.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_36352.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_35712.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12960_y_36096.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12960_y_36352.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36384.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36352.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36448.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36448.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_36288.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36128.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12064_y_34560.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_35808.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34592.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12960_y_36448.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_35616.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36224.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36000.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_35840.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36096.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36416.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36352.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12064_y_34720.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_36064.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_35648.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36352.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_36096.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13120_y_35872.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36192.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36256.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_35840.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36128.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_35872.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_36064.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36032.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36320.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_36224.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_35968.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_35808.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36064.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12896_y_36224.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_35744.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36128.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36352.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36096.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13312_y_36224.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35872.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34624.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_35872.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36384.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_35936.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_35936.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36192.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34720.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_36096.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36352.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36128.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_35488.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36512.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13376_y_36192.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36096.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36160.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13376_y_36128.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36064.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36256.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_36128.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36480.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36480.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36480.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36544.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_36032.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36384.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36096.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36320.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36224.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12896_y_36256.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13344_y_36288.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34656.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36544.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36064.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_35936.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35776.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36064.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13120_y_36000.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35744.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12960_y_36416.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36224.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36192.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_35840.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_35872.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36160.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12160_y_34720.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36192.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_35872.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36096.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36480.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34592.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35936.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36096.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13376_y_36160.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36288.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_35872.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12064_y_34688.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36160.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12896_y_35616.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36096.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35616.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_35840.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36000.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36032.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_35840.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35808.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_35904.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36032.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34752.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36032.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13120_y_35968.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12064_y_34624.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_35680.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_35808.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12160_y_34688.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_35968.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36000.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36416.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36288.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36352.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36096.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12960_y_36064.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34688.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13152_y_35936.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13152_y_36000.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36224.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13344_y_36192.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36096.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35744.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12064_y_34592.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34688.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_35840.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36320.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_35584.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34656.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36416.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_35680.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36000.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_35936.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34784.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_35936.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35648.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_35712.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36416.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36256.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36416.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36192.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_36320.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12896_y_36288.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36064.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36128.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36128.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36096.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36128.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36384.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12160_y_34656.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36448.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36352.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_36000.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36320.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_35840.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12064_y_34528.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_35936.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36416.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36320.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_35872.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36256.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36544.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13120_y_36032.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36224.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_35904.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36064.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36160.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36256.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_35488.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36448.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36288.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_35936.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36096.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36256.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36128.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13280_y_36288.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_35968.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35680.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_36160.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_35648.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12960_y_36384.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36288.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36128.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36512.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36160.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13344_y_36224.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_36128.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36128.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36480.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36224.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_36256.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13184_y_35904.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36256.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_35776.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36160.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_35840.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_35904.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35968.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_35968.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13344_y_36160.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_35968.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_35872.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36032.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13312_y_36288.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_35776.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35808.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13184_y_35872.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_35904.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13152_y_35968.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_35904.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36256.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36288.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36448.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36000.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36448.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_36000.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_36032.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_35872.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36320.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36384.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13152_y_35904.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36256.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16416.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16544.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16384_y_24352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_16544.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16288_y_24448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17216.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_15968.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_17056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18400_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16032.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16544_y_24768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_16288.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17024_y_24512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_16224.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15488.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_17440.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17536_y_15552.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16384.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_15904.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16320.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17472.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16864.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16192.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17632_y_15104.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_17024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16704.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_15648.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16416.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16832.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17216.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_15616.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16288.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_15680.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16896.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16320_y_25056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17344_y_28160.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16704.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16928.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15520.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16480_y_24992.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17536.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_17216.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_15872.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17088_y_24448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_16192.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16160.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17088_y_24512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19360_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_17024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16160.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_16352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_17088.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16160.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16256.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16896.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16544.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16416.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16320.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_17152.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_17248.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_17024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16320_y_24448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_16800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16768_y_25152.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16736_y_25152.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16704_y_24640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16064.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17600_y_15488.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_17376.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16864_y_24288.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17120_y_25088.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16544_y_24800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18304_y_16064.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17056_y_24512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16096.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16384.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17280_y_24768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16288.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_17056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16768_y_24928.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16320.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_17024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17184.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16288_y_25024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_15648.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16320.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16352_y_25024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17184.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_17056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16544.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16288.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_15712.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_17152.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_16480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17216.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16544.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_15744.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16096.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16256_y_24448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16480_y_24832.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_25024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16832.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_16448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_15680.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16480_y_24800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16992_y_24448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_16768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16224.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_17280.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16832_y_24288.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_17408.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16192.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16960.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16000.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17088.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17248_y_24768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16832.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17504.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17632_y_15520.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16832.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16992.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17248.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17504.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16256_y_24480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16000.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16320_y_24608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16416.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17280_y_28064.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16032.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17536_y_25120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17152.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16192.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17408_y_24896.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19296_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16096.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16992.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16704.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16864.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_17056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_17120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16352_y_24512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_15648.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16896.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_16704.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_17088.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16256.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19360_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16768_y_25184.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16288.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17184.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16160_y_24864.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16032.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16416.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16640_y_25440.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17280_y_28192.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16800_y_25120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_17344.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16224.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16384.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16832.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16416.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18400_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17600_y_15392.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16320_y_24480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_15968.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16640_y_24576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16992_y_24960.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16672_y_24512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18336_y_16064.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16096.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16128_y_24896.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17440.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_17248.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16544_y_25024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16064.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_17440.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17536.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17440_y_24928.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_17216.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16864_y_24352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16224.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_17184.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16064.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16000.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_17344.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_15680.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16672_y_24928.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17536_y_15520.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16704_y_24608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16544.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_16416.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17312_y_28128.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17312.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_17120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16448_y_24768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16128.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16640_y_24640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16480_y_25024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17152.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_17120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_17280.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_15584.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_15648.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17408_y_24928.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16832.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17344_y_28128.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_17184.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_17440.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16864.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18400_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_16960.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_15968.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_15680.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_17024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_15616.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_16192.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18336_y_16160.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16864.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17024_y_24960.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16320_y_24640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15328.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17408.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_15712.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16736_y_24608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16704.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16928.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17088.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16704.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16960.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16960.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16096.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16448_y_25024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16384.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16832.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17376_y_24896.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_15648.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16064.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_15584.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16192.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16544.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16864.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16768_y_25120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_17248.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16384_y_24640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_17184.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_17088.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_17376.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_17088.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16544.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16064.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16512_y_24352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17568_y_15488.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_16128.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16992.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16704.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15456.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16480_y_25056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_17088.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16256.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19296_y_16832.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16928.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17344.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_17120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16352_y_24608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15104.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16320.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16800_y_24928.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16672_y_24960.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17696_y_15488.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_15744.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_17344.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16704_y_24960.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17184_y_25120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16064.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16992_y_24480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_15776.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_15744.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16096.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19328_y_16800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17856_y_15296.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_15712.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_15904.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16384_y_24512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16992.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16096.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16512_y_24800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_17216.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_17056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_25120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17280.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16416_y_24320.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_17376.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16576_y_24768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_17440.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17312_y_28032.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16992.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17568_y_25216.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16320.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_17120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16256.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_17024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17056_y_24480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16832_y_24864.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16704.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16288_y_24416.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16704.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16192.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16256.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16960.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17472.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16320.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_15776.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_17440.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16192.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_17216.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16544.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16064.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_17376.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17152_y_25088.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16032.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16544.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16224.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17312.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_17056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16384.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_16416.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16416.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16576_y_24608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17248.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16896.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17248_y_24800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16928.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16896.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16416.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16000.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16640_y_24384.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16416.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17632_y_15072.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_17408.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17600_y_15136.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_15936.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17600_y_15072.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16000.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16000.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16864_y_24320.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16800_y_25152.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16896_y_24576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_17440.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16096.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16384.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16384.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15200.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17120_y_25152.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17568_y_15392.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16416_y_24352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16512_y_24320.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16096.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16480_y_24352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16960.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16960.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_17248.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16576_y_25024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16032.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16704.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_15648.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16448_y_24384.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_25088.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17568_y_25152.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17216.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18336_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16928.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15424.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16992.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16288.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15072.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16992.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_17280.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18400_y_16352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17696_y_15232.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_17152.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16288.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15136.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_15712.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_15712.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16704.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_15680.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19296_y_16544.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17248.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16544_y_25056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_16832.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16288_y_24480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16384.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_17248.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16128.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_17440.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17696_y_15520.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16512_y_25024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16064.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_16128.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_17184.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16416.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17280_y_24800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_16256.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_16096.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19328_y_16832.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16384_y_24320.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17600_y_25216.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_17248.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17280.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16960.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16640_y_25408.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16736_y_24640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16288_y_25056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17152.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16224.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_17152.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_17024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16928.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16736_y_24576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16640_y_25056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16512_y_24768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_15680.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16896_y_24320.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16864.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16896.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_16448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17056_y_24448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16704.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16160.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16672_y_25088.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16864.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17600_y_15360.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16032.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16704.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_15936.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_17088.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17632_y_15200.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_15584.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19328_y_16768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16416.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16288_y_25088.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16480_y_24768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16832.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17408.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16704.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_16512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_17312.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16288_y_24512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_17408.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_17024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16640_y_24608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16128.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16672_y_24640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17152_y_25120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17024_y_24928.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17568_y_25120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16384.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_17184.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16960.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24384.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16160.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17632_y_15488.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_15648.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_17056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_15744.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16320_y_25024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16736_y_24672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17632_y_15424.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_16064.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24992.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_17120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19296_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16896.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18304_y_16032.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17632_y_15136.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_17056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16128.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16032.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_17408.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16544.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_17024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17568_y_15360.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16416.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_17280.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17088.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16544.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_15968.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16288.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_17440.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16672_y_24608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16448_y_24352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18336_y_16032.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16384.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19360_y_16768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_17184.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16832.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16992.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_17152.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16160.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17504.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_15936.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24960.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16672.png',\n", - " ...]" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" } ], "source": [ @@ -1779,56 +824,105 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ - "import copy\n", - "full_dataset_camelyon17 = copy.deepcopy(full_dataset)" + "full_dataset.metadata_fields\n", + "config = config_encode\n", + "#config_encode.groupby_fields\n", + "\n", + "train_grouper = CombinatorialGrouper(\n", + " dataset=full_dataset,\n", + " groupby_fields=config.groupby_fields)" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 20, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "image_base None\n" - ] + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "supported.datasets[config_encode.dataset]\n", - "print(config_camelyon.train_transform, config_encode.train_transform)" + "config_encode.eval_splits" ] }, { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], - "source": [] + "source": [ + "# Train/eval" + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train data...\n", + " celltype = H1-hESC: n = 507309\n", + " celltype = HCT116: n = 506458\n", + " celltype = HeLa-S3: n = 509974\n", + " celltype = HepG2: n = 503007\n", + " celltype = K562: n = 506847\n", + " celltype = A549: n = 0\n", + " celltype = GM12878: n = 0\n", + "Validation (ID) data...\n", + " celltype = H1-hESC: n = 433473\n", + " celltype = HCT116: n = 431398\n", + " celltype = HeLa-S3: n = 435455\n", + " celltype = HepG2: n = 433039\n", + " celltype = K562: n = 430163\n", + " celltype = A549: n = 0\n", + " celltype = GM12878: n = 0\n", + "Test data...\n", + " celltype = H1-hESC: n = 0\n", + " celltype = HCT116: n = 0\n", + " celltype = HeLa-S3: n = 0\n", + " celltype = HepG2: n = 0\n", + " celltype = K562: n = 0\n", + " celltype = A549: n = 0\n", + " celltype = GM12878: n = 437124\n", + "Validation (OOD) data...\n", + " celltype = H1-hESC: n = 0\n", + " celltype = HCT116: n = 0\n", + " celltype = HeLa-S3: n = 0\n", + " celltype = HepG2: n = 0\n", + " celltype = K562: n = 0\n", + " celltype = A549: n = 433986\n", + " celltype = GM12878: n = 0\n" + ] + }, + { + "ename": "ValueError", + "evalue": "Model not recognized.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;31m## Initialize algorithm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 62\u001b[0;31m algorithm = initialize_algorithm(\n\u001b[0m\u001b[1;32m 63\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdatasets\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/examples/algorithms/initializer.py\u001b[0m in \u001b[0;36minitialize_algorithm\u001b[0;34m(config, datasets, train_grouper)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m'ERM'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m algorithm = ERM(\n\u001b[0m\u001b[1;32m 33\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0md_out\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0md_out\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/examples/algorithms/ERM.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, config, d_out, grouper, loss, metric, n_train_steps)\u001b[0m\n\u001b[1;32m 6\u001b[0m def __init__(self, config, d_out, grouper, loss,\n\u001b[1;32m 7\u001b[0m metric, n_train_steps):\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minitialize_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0md_out\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0;31m# initialize module\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m super().__init__(\n", + "\u001b[0;32m~/wilds/examples/models/initializer.py\u001b[0m in \u001b[0;36minitialize_model\u001b[0;34m(config, d_out)\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mGINVirtual\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnum_tasks\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0md_out\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 30\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Model not recognized.'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 31\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mValueError\u001b[0m: Model not recognized." + ] + } + ], "source": [ - "train_grouper = CombinatorialGrouper(\n", - " dataset=full_dataset,\n", - " groupby_fields=config.groupby_fields)\n", - "\n", "datasets = defaultdict(dict)\n", "for split in full_dataset.split_dict.keys():\n", " if split=='train':\n", @@ -1898,81 +992,26 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", - "val_celltype = ['A549']\n", - "test_celltype = ['GM12878']\n", - "all_celltypes = train_celltypes + val_celltype + test_celltype\n", - "\n", - "metadata_map = {}\n", - "metadata_map['chr'] = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']\n", - "metadata_map['celltype'] = all_celltypes\n", - "\n", - "_split_dict = {\n", - " 'train': 0,\n", - " 'val-id': 1,\n", - " 'test': 2,\n", - " 'val-ood': 3\n", - "}\n", - "_split_names = {\n", - " 'train': 'Train',\n", - " 'val-id': 'Validation (ID)',\n", - " 'test': 'Test',\n", - " 'val-ood': 'Validation (OOD)'\n", - "}\n", - "_split_scheme = 'standard'" + "for " ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "('H1-hESC', 25.299736976623535)\n", - "('HCT116', 49.68733310699463)\n", - "('HeLa-S3', 74.65905213356018)\n", - "('HepG2', 99.33112812042236)\n", - "('K562', 124.1327919960022)\n", - "('A549', 149.19999814033508)\n", - "('GM12878', 174.0277030467987)\n" - ] - } - ], - "source": [ - "itime = time.time()\n", - "sequence_filename = os.path.join(data_dir, 'sequence.npz')\n", - "seq_arr = np.load(sequence_filename)\n", - "print(time.time() - itime)\n", - "\n", - "itime = time.time()\n", - "_seq_bp = {}\n", - "for chrom in seq_arr:\n", - " _seq_bp[chrom] = seq_arr[chrom]\n", - " print(chrom, time.time() - itime)\n", - "itime = time.time()\n", - "_dnase_allcelltypes = {}\n", - "for ct in all_celltypes:\n", - " dnase_filename = os.path.join(data_dir, '{}_dnase.npz'.format(ct))\n", - " dnase_npz_file = np.load(dnase_filename)\n", - " _dnase_allcelltypes[ct] = {}\n", - " for chrom in _seq_bp:\n", - " _dnase_allcelltypes[ct][chrom] = dnase_npz_file[chrom]\n", - " print(ct, time.time() - itime)" - ] + "outputs": [], + "source": [] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "# Train/eval" - ] + "outputs": [], + "source": [] }, { "cell_type": "code", diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index f8b66f25..184da7cd 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -108,6 +108,20 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): metadata_df['y'] = y_array self._metadata_df = metadata_df[non_ambig_mask] + samp_ndces = [] + itime = time.time() + for ct in self._all_celltypes: + neg_msk = np.logical_and((self._metadata_df['celltype'] == ct), (self._metadata_df['y'] == 0)) + pos_msk = np.logical_and((self._metadata_df['celltype'] == ct), (self._metadata_df['y'] == 1)) + neg_ndces = np.where(neg_msk)[0] + pos_ndces = np.where(pos_msk)[0] + np.random.seed(42) + samp_neg_ndces = np.random.choice(neg_ndces, size=len(pos_ndces), replace=False) + samp_ndces.extend(samp_neg_ndces) + samp_ndces.extend(pos_ndces) + print(ct, time.time() - itime) + self._metadata_df = self._metadata_df.iloc[samp_ndces, :] + train_regions_mask = np.isin(self._metadata_df['chr'], self._train_chroms) val_regions_mask = np.isin(self._metadata_df['chr'], self._test_chroms) train_celltype_mask = np.isin(self._metadata_df['celltype'], self._train_celltypes) From e23f65323d86fc0fe247fc9a848b6811589df0ee Mon Sep 17 00:00:00 2001 From: aikanor Date: Fri, 26 Feb 2021 07:39:22 -0800 Subject: [PATCH 024/244] adding new architecture --- dataset_preprocessing/encode-tfbs/README.md | 1 + examples/configs/datasets.py | 4 +- examples/configs/model.py | 2 +- examples/models/CNN_genome.py | 127 ++- examples/models/initializer.py | 2 + examples/sbox_run_expt.ipynb | 1026 +++++++++++++++---- wilds/datasets/encodetfbs_dataset.py | 13 +- 7 files changed, 928 insertions(+), 247 deletions(-) diff --git a/dataset_preprocessing/encode-tfbs/README.md b/dataset_preprocessing/encode-tfbs/README.md index 616d4cb5..bf3f92c6 100644 --- a/dataset_preprocessing/encode-tfbs/README.md +++ b/dataset_preprocessing/encode-tfbs/README.md @@ -16,3 +16,4 @@ 5. Download the labels from the challenge into a label directory created for this purpose: - The training labels from https://www.synapse.org/#!Synapse:syn7413983 for the relevant transcription factor (e.g. https://www.synapse.org/#!Synapse:syn7415202 for the TF MAX). - The validation labels from https://www.synapse.org/#!Synapse:syn8441154 for the relevant transcription factor (e.g. https://www.synapse.org/#!Synapse:syn8442103 for the TF MAX). + - (Optional) The validation labels for the challenge's evaluation cell type from https://www.synapse.org/#!Synapse:syn8442975 for the relevant transcription factor (generally primary liver cells, e.g. https://www.synapse.org/#!Synapse:syn8443021 for the TF MAX). diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 4bd331a0..c4b900fd 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -91,7 +91,7 @@ }, 'encode-tfbs': { 'split_scheme': 'official', - 'model': 'beagle', + 'model': 'leopard', 'model_kwargs': {'pretrained': False}, 'train_transform': None, 'eval_transform': None, @@ -105,7 +105,7 @@ 'batch_size': 64, 'lr': 0.001, 'weight_decay': 0.01, - 'n_epochs': 1, + 'n_epochs': 5, 'n_groups_per_batch': 2, 'algo_log_metric': 'accuracy', # 'irm_lambda': 1.0, diff --git a/examples/configs/model.py b/examples/configs/model.py index 31539d03..6d4c88c3 100644 --- a/examples/configs/model.py +++ b/examples/configs/model.py @@ -27,5 +27,5 @@ 'target_resolution': (224, 224), }, 'logistic_regression': {}, - 'beagle': {}, + 'leopard': {}, } diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py index 1c65b567..3efdba20 100644 --- a/examples/models/CNN_genome.py +++ b/examples/models/CNN_genome.py @@ -4,51 +4,82 @@ import torch.nn as nn import torch.nn.functional as F -class Beagle(nn.Module): - """ - Neural net models over genomic sequence. Adapted from https://github.com/kundajelab/ChromDragoNN - - Input: - - s (Tensor): float torch tensor of shape (N, 5, 1000, 1) with batch size N. - - Output: - - prediction (Tensor): float torch tensor of shape (N, ) - """ - def __init__(self): - super(Beagle, self).__init__() - - self.dropout = 0.3 - self.num_cell_types = 1 - self.conv1 = nn.Conv2d(5, 300, (19, 1), stride = (1, 1), padding=(9,0)) - self.conv2 = nn.Conv2d(300, 200, (11, 1), stride = (1, 1), padding = (5,0)) - self.conv3 = nn.Conv2d(200, 200, (7, 1), stride = (1, 1), padding = (4,0)) - self.bn1 = nn.BatchNorm2d(300) - self.bn2 = nn.BatchNorm2d(200) - self.bn3 = nn.BatchNorm2d(200) - self.maxpool1 = nn.MaxPool2d((3, 1)) - self.maxpool2 = nn.MaxPool2d((4, 1)) - self.maxpool3 = nn.MaxPool2d((4, 1)) - - self.fc1 = nn.Linear(4200, 1000) - self.bn4 = nn.BatchNorm1d(1000) - - self.fc2 = nn.Linear(1000, 1000) - self.bn5 = nn.BatchNorm1d(1000) - - self.fc3 = nn.Linear(1000, self.num_cell_types) - - def forward(self, s): - s = s.permute(0, 2, 1).contiguous() # batch_size x 5 x 1000 - s = s.view(-1, 5, 1000, 1) # batch_size x 5 x 1000 x 1 [5 channels] - s = self.maxpool1(F.relu(self.bn1(self.conv1(s)))) # batch_size x 300 x 333 x 1 - s = self.maxpool2(F.relu(self.bn2(self.conv2(s)))) # batch_size x 200 x 83 x 1 - s = self.maxpool3(F.relu(self.bn3(self.conv3(s)))) # batch_size x 200 x 21 x 1 - s = s.view(-1, 4200) - conv_out = s - - s = F.dropout(F.relu(self.bn4(self.fc1(s))), p=self.dropout, training=self.training) # batch_size x 1000 - s = F.dropout(F.relu(self.bn5(self.fc2(s))), p=self.dropout, training=self.training) # batch_size x 1000 - - prediction = self.fc3(s) - - return s #, conv_out + + +def double_conv(in_channels, out_channels): + return nn.Sequential( + nn.Conv1d(in_channels, out_channels, 7, padding=3), + nn.BatchNorm1d(out_channels), + nn.ReLU(inplace=True), + nn.Conv1d(out_channels, out_channels, 7, padding=3), + nn.BatchNorm1d(out_channels), + nn.ReLU(inplace=True) + ) + + +class UNet(nn.Module): + + def __init__(self, n_class): + super().__init__() + + self.dconv_down1 = double_conv(6, 15) + self.dconv_down2 = double_conv(15, 22) + self.dconv_down3 = double_conv(22, 33) + self.dconv_down4 = double_conv(33, 49) + self.dconv_down5 = double_conv(49, 73) + self.dconv_down6 = double_conv(73, 109) + + self.maxpool = nn.MaxPool1d(2) + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + + self.dconv_up5 = double_conv(73 + 109, 73) + self.dconv_up4 = double_conv(49 + 73, 49) + self.dconv_up3 = double_conv(33 + 49, 33) + self.dconv_up2 = double_conv(22 + 33, 22) + self.dconv_up1 = double_conv(15 + 22, 15) + + self.conv_last = nn.Conv2d(15, n_class, 1) + + + def forward(self, x): + conv1 = self.dconv_down1(x) + x = self.maxpool(conv1) + + conv2 = self.dconv_down2(x) + x = self.maxpool(conv2) + + conv3 = self.dconv_down3(x) + x = self.maxpool(conv3) + + conv4 = self.dconv_down4(x) + x = self.maxpool(conv4) + + conv5 = self.dconv_down5(x) + x = self.maxpool(conv5) + + x = self.dconv_down6(x) + + x = self.upsample(x) + x = torch.cat([x, conv5], dim=1) + + x = self.dconv_up5(x) + x = self.upsample(x) + x = torch.cat([x, conv4], dim=1) + + x = self.dconv_up4(x) + x = self.upsample(x) + x = torch.cat([x, conv3], dim=1) + + x = self.dconv_up3(x) + x = self.upsample(x) + x = torch.cat([x, conv2], dim=1) + + x = self.dconv_up2(x) + x = self.upsample(x) + x = torch.cat([x, conv1], dim=1) + + x = self.dconv_up1(x) + + out = self.conv_last(x) + + return out diff --git a/examples/models/initializer.py b/examples/models/initializer.py index fb77a5ea..5de63c54 100644 --- a/examples/models/initializer.py +++ b/examples/models/initializer.py @@ -27,6 +27,8 @@ def initialize_model(config, d_out): model = nn.Linear(out_features=d_out, **config.model_kwargs) elif config.model == 'gin-virtual': model = GINVirtual(num_tasks=d_out, **config.model_kwargs) + # elif config.model == 'leopard': + # model = GINVirtual(num_tasks=d_out, **config.model_kwargs) else: raise ValueError('Model not recognized.') return model diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb index e50f790b..2c56cdd6 100644 --- a/examples/sbox_run_expt.ipynb +++ b/examples/sbox_run_expt.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 123, "metadata": {}, "outputs": [ { @@ -21,7 +21,7 @@ "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpsutil\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpsutil\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mProcess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgetpid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmemory_info\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrss\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0;36m1024\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpsutil\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpsutil\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mProcess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgetpid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmemory_info\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrss\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0;36m1024\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'psutil'" ] } @@ -32,9 +32,17 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:The OGB package is out of date. Your version is 1.2.4, while the latest version is 1.2.5.\n" + ] + } + ], "source": [ "import os, csv\n", "import time\n", @@ -59,16 +67,16 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" + "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" ] }, - "execution_count": 3, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -163,7 +171,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -175,12 +183,13 @@ "config_encode = parser.parse_args(argstr_encode.split())\n", "config_encode = populate_defaults(config_encode)\n", "\n", - "config = config_camelyon" + "config = config_camelyon\n", + "# config = config_encode" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -292,7 +301,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -312,30 +321,26 @@ }, { "cell_type": "code", - "execution_count": 30, - "metadata": { - "jupyter": { - "source_hidden": true - } - }, + "execution_count": 7, + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "chr2 3.657395362854004\n", - "chr9 5.770605564117432\n", - "chr11 7.801896095275879\n", - "chr1 11.56663990020752\n", - "chr8 13.764073133468628\n", - "chr21 14.483267068862915\n", - "H1-hESC 20.850953817367554\n", - "HCT116 27.05355429649353\n", - "HeLa-S3 33.51919412612915\n", - "HepG2 39.89570116996765\n", - "K562 46.36982774734497\n", - "A549 52.82617139816284\n", - "GM12878 59.167165994644165\n" + "chr2 3.7666022777557373\n", + "chr9 5.9439966678619385\n", + "chr11 8.030796766281128\n", + "chr1 11.851332426071167\n", + "chr8 14.106642007827759\n", + "chr21 14.852506160736084\n", + "H1-hESC 14.853845119476318\n", + "HCT116 14.853914022445679\n", + "HeLa-S3 14.853951930999756\n", + "HepG2 14.853987216949463\n", + "K562 14.854026317596436\n", + "A549 14.8540620803833\n", + "GM12878 14.854098796844482\n" ] } ], @@ -371,7 +376,7 @@ "_all_celltypes = _train_celltypes + _val_celltype + _test_celltype\n", "\n", "_metadata_map = {}\n", - "_metadata_map['chr'] = _all_chroms #['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']\n", + "_metadata_map['chr'] = _all_chroms\n", "_metadata_map['celltype'] = _all_celltypes\n", "\n", "# Get the splits\n", @@ -402,12 +407,35 @@ "\n", "_dnase_allcelltypes = {}\n", "for ct in _all_celltypes:\n", + " \"\"\"\n", " dnase_filename = os.path.join(_data_dir, '{}_dnase.npz'.format(ct))\n", " dnase_npz_contents = np.load(dnase_filename)\n", " _dnase_allcelltypes[ct] = {}\n", " for chrom in _all_chroms: #_seq_bp:\n", " _dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom]\n", - " print(ct, time.time() - itime)\n", + " \"\"\"\n", + " _dnase_allcelltypes[ct] = 'DNASE.{}.fc.signal.bigwig'\n", + " print(ct, time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\"\\nfilter_msk = all_df['start'] >= 0\\nfilter_msk = all_df['start']%1000 == 0\\nall_df = all_df[filter_msk]\\n\"" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "itime = time.time()\n", "\n", "# Read in metadata dataframe from training+validation data\n", "train_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.train.labels.tsv.gz'.format(_transcription_factor)), sep='\\t')\n", @@ -416,36 +444,384 @@ "val_df = val_regions_labeled[np.isin(val_regions_labeled['chr'], _test_chroms)]\n", "all_df = pd.concat([training_df, val_df])\n", "\n", - "# Filter by start/stop coordinate if needed (TODO: remove for final version)\n", - "# filter_msk = all_df['start'] >= 0\n", - "# filter_msk = all_df['start']%1000 == 0\n", - "# all_df = all_df[filter_msk]\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "# train_regions_labeled.replace({'U': 0, 'B': 1, 'A': -1})\n", + "# a = \n", + "# np.random.choice(train_regions_labeled.shape[0], size=100000)\n", + "\n", + "v = val_regions_labeled.replace({'U': 0, 'B': 1, 'A': -1})\n", + "# seta = [full_dataset_encode.get_input(x) for x in a]\n", + "# seta[0].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([-7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7]),\n", + " array([ 189, 854, 3579, 11535, 35901, 126629, 621676,\n", + " 7944663, 67689, 13516, 6766, 3332, 3179, 1076,\n", + " 2427]))" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.unique(v[['A549', 'GM12878', 'H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']].sum(axis=1), return_counts=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + ":12: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " tc_chr['y'] = y_array\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "11.363114833831787\n", + "21.872379302978516\n", + "32.51760506629944\n", + "42.88175559043884\n", + "53.35902285575867\n", + "63.94557332992554\n", + "74.44822382926941\n", + "92.237633228302\n" + ] + } + ], + "source": [ + "itime = time.time()\n", "\n", + "# Get the y values, and remove ambiguous labels by default.\n", "pd_list = []\n", "for ct in _all_celltypes:\n", " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", + " y_array = tc_chr['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", + " \n", + " # Now filter out ambiguous labels\n", + " non_ambig_mask = (y_array != -1)\n", + " tc_chr['y'] = y_array\n", + " tc_chr = tc_chr[non_ambig_mask]\n", + " \n", " tc_chr.insert(len(tc_chr.columns), 'celltype', ct)\n", " pd_list.append(tc_chr)\n", + " print(time.time() - itime)\n", "metadata_df = pd.concat(pd_list)\n", "\n", - "# Get the y values, and remove ambiguous labels by default.\n", - "y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", - "non_ambig_mask = (y_array != -1)\n", - "metadata_df['y'] = y_array\n", - "_metadata_df = metadata_df[non_ambig_mask]" + "print(time.time() - itime)\n", + "\n", + "# y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", + "# non_ambig_mask = (y_array != -1)\n", + "# metadata_df['y'] = y_array\n", + "# _metadata_df = metadata_df[non_ambig_mask]\n", + "\n", + "# print(time.time() - itime)" ] }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 75, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
chrstartstopycelltype
2702470chr116008000H1-hESC
2702471chr116508500H1-hESC
2702472chr117009000H1-hESC
2702473chr117509500H1-hESC
2702474chr1180010000H1-hESC
..................
8843006chr81463632001463634000GM12878
8843007chr81463632501463634500GM12878
8843008chr81463633001463635000GM12878
8843009chr81463633501463635500GM12878
8843010chr81463634001463636000GM12878
\n", + "

131721055 rows × 5 columns

\n", + "
" + ], + "text/plain": [ + " chr start stop y celltype\n", + "2702470 chr11 600 800 0 H1-hESC\n", + "2702471 chr11 650 850 0 H1-hESC\n", + "2702472 chr11 700 900 0 H1-hESC\n", + "2702473 chr11 750 950 0 H1-hESC\n", + "2702474 chr11 800 1000 0 H1-hESC\n", + "... ... ... ... .. ...\n", + "8843006 chr8 146363200 146363400 0 GM12878\n", + "8843007 chr8 146363250 146363450 0 GM12878\n", + "8843008 chr8 146363300 146363500 0 GM12878\n", + "8843009 chr8 146363350 146363550 0 GM12878\n", + "8843010 chr8 146363400 146363600 0 GM12878\n", + "\n", + "[131721055 rows x 5 columns]" + ] + }, + "execution_count": 75, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metadata_df\n", + "# tc_chr" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [], + "source": [ + "# window_size = 12800\n", + "# window_interval = window_size/2\n", + "# trl_mask = (train_regions_labeled['start']%window_interval == 0)\n", + "# train_regions_labeled[trl_mask]" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "686900" + ] + }, + "execution_count": 68, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(metadata_df['y'] == 1).sum()\n", + "# pd_list[0][non_ambig_mask]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# celltype_msk.sum()\n", + "\n", + "np.unique(_metadata_df['chr'])\n", + "\n", + "# celltype_msk = (_metadata_df['celltype'] == ct)\n", + "# np.where(celltype_msk)[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "_metadata_df" + ] + }, + { + "cell_type": "code", + "execution_count": 24, "metadata": { + "collapsed": true, "jupyter": { - "source_hidden": true + "outputs_hidden": true } }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "H1-hESC 32.5968804359436\n", + "H1-hESC 33.237690687179565\n", + "H1-hESC 37.01208806037903\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mpos_msk\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0m_metadata_df\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'y'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mct\u001b[0m \u001b[0;32min\u001b[0m \u001b[0m_all_celltypes\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mcelltype_msk\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0m_metadata_df\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'celltype'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mct\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0mneg_ct_msk\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogical_and\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcelltype_msk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mneg_msk\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mpos_ct_msk\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogical_and\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcelltype_msk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpos_msk\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/pandas/core/ops/common.py\u001b[0m in \u001b[0;36mnew_method\u001b[0;34m(self, other)\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0mother\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mitem_from_zerodim\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mother\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 65\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mother\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 66\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mnew_method\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/pandas/core/arraylike.py\u001b[0m in \u001b[0;36m__eq__\u001b[0;34m(self, other)\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0munpack_zerodim_and_defer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"__eq__\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__eq__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mother\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 29\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cmp_method\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mother\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moperator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meq\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0munpack_zerodim_and_defer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"__ne__\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/pandas/core/series.py\u001b[0m in \u001b[0;36m_cmp_method\u001b[0;34m(self, other, op)\u001b[0m\n\u001b[1;32m 4946\u001b[0m \u001b[0mrvalues\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mextract_array\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mother\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mextract_numpy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4947\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 4948\u001b[0;31m \u001b[0mres_values\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcomparison_op\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlvalues\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrvalues\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4949\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4950\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_construct_result\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mres_values\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mres_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/pandas/core/ops/array_ops.py\u001b[0m in \u001b[0;36mcomparison_op\u001b[0;34m(left, right, op)\u001b[0m\n\u001b[1;32m 241\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 242\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mis_object_dtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlvalues\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 243\u001b[0;31m \u001b[0mres_values\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcomp_method_OBJECT_ARRAY\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mop\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlvalues\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrvalues\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 244\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 245\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/pandas/core/ops/array_ops.py\u001b[0m in \u001b[0;36mcomp_method_OBJECT_ARRAY\u001b[0;34m(op, x, y)\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlibops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvec_compare\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mravel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mravel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 55\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlibops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscalar_compare\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mravel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 56\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "# Downsample negatives to balance each celltype\n", + "samp_ndces = []\n", + "itime = time.time()\n", + "neg_msk = (_metadata_df['y'] == 0)\n", + "pos_msk = (_metadata_df['y'] == 1)\n", + "for ct in _all_celltypes:\n", + " celltype_msk = (_metadata_df['celltype'] == ct)\n", + " neg_ct_msk = np.logical_and(celltype_msk, neg_msk)\n", + " pos_ct_msk = np.logical_and(celltype_msk, pos_msk)\n", + " print(ct, time.time() - itime)\n", + " neg_ndces = np.where(neg_ct_msk)[0]\n", + " pos_ndces = np.where(pos_ct_msk)[0]\n", + " print(ct, time.time() - itime)\n", + " np.random.seed(42)\n", + " samp_neg_ndces = np.random.choice(neg_ndces, size=len(pos_ndces), replace=False)\n", + " samp_ndces.extend(samp_neg_ndces)\n", + " samp_ndces.extend(pos_ndces)\n", + " print(ct, time.time() - itime)\n", + "_metadata_df = _metadata_df.iloc[samp_ndces, :]" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, "outputs": [], "source": [ + "# Downsample negatives to balance each celltype\n", "samp_ndces = []\n", "itime = time.time()\n", "for ct in _all_celltypes:\n", @@ -458,7 +834,6 @@ " samp_ndces.extend(samp_neg_ndces)\n", " samp_ndces.extend(pos_ndces)\n", " print(ct, time.time() - itime)\n", - "\n", "_metadata_df = _metadata_df.iloc[samp_ndces, :]\n", "\n", "train_regions_mask = np.isin(_metadata_df['chr'], _train_chroms)\n", @@ -503,8 +878,12 @@ }, { "cell_type": "code", - "execution_count": 23, - "metadata": {}, + "execution_count": 19, + "metadata": { + "jupyter": { + "source_hidden": true + } + }, "outputs": [], "source": [ "import os, time\n", @@ -542,8 +921,8 @@ " self._y_size = 1\n", " self._n_classes = 2\n", " \n", - " self._train_chroms = ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX']\n", - " # self._train_chroms = ['chr2', 'chr9', 'chr11']\n", + " # self._train_chroms = ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX']\n", + " self._train_chroms = ['chr2', 'chr9', 'chr11']\n", " self._test_chroms = ['chr1', 'chr8', 'chr21']\n", " self._transcription_factor = 'MAX'\n", " self._train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", @@ -553,7 +932,7 @@ " self._all_celltypes = self._train_celltypes + self._val_celltype + self._test_celltype\n", " \n", " self._metadata_map = {}\n", - " self._metadata_map['chr'] = self._all_chroms #['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']\n", + " self._metadata_map['chr'] = self._all_chroms\n", " self._metadata_map['celltype'] = self._all_celltypes\n", " \n", " # Get the splits\n", @@ -696,43 +1075,38 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 20, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "chr2 3.718320846557617\n", - "chr3 6.73882269859314\n", - "chr4 9.651247501373291\n", - "chr5 12.439628839492798\n", - "chr6 15.05026388168335\n", - "chr7 17.475954055786133\n", - "chr9 19.6206693649292\n", - "chr10 21.68758535385132\n", - "chr11 23.74817419052124\n", - "chr12 25.81403160095215\n", - "chr13 27.559557676315308\n", - "chr14 29.18643832206726\n", - "chr15 30.739391565322876\n", - "chr16 32.11144256591797\n", - "chr17 33.348127126693726\n", - "chr18 34.53834342956543\n", - "chr19 35.434733629226685\n", - "chr20 36.399296283721924\n", - "chr22 37.1924102306366\n", - "chrX 39.56284308433533\n", - "chr1 43.3526566028595\n", - "chr8 45.583492040634155\n", - "chr21 46.311339378356934\n", - "H1-hESC 66.45292735099792\n", - "HCT116 86.06067085266113\n", - "HeLa-S3 106.47142815589905\n", - "HepG2 126.59437656402588\n", - "K562 146.93650436401367\n", - "A549 167.19306707382202\n", - "GM12878 187.4349775314331\n" + "chr2 3.7390823364257812\n", + "chr9 5.909312963485718\n", + "chr11 8.020122051239014\n", + "chr1 11.871179103851318\n", + "chr8 14.147786140441895\n", + "chr21 14.896430492401123\n", + "H1-hESC 21.391544818878174\n", + "HCT116 27.753155946731567\n", + "HeLa-S3 34.33590316772461\n", + "HepG2 40.81141257286072\n", + "K562 47.39495897293091\n", + "A549 54.245203495025635\n", + "GM12878 60.693068742752075\n", + "H1-hESC 16.79085922241211\n", + "HCT116 33.788668394088745\n", + "HeLa-S3 51.1968936920166\n", + "HepG2 68.32299137115479\n", + "K562 85.74746584892273\n", + "A549 103.05137896537781\n", + "GM12878 120.52022075653076\n" ] } ], @@ -746,24 +1120,33 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 39, + "metadata": {}, + "outputs": [], + "source": [ + "a = np.random.choice(1210796, size=128)\n", + "seta = [full_dataset_encode.get_input(x) for x in a]\n", + "seta[0].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "(array(['A549', 'GM12878', 'H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'],\n", - " dtype=object),\n", - " array([ 5118, 1702, 8460, 12806, 8348, 11774, 12518]))" + "(array([0, 1, 2, 3]), array([2804551, 498433, 34145, 100851]))" ] }, - "execution_count": 20, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "np.unique(full_dataset_encode._metadata_df['celltype'], return_counts=True)" + "np.unique(full_dataset_encode._metadata_df['split'], return_counts=True)" ] }, { @@ -810,26 +1193,20 @@ "print(np.unique(full_dataset.y_array.numpy(), return_counts=True))\n", "print(np.unique(full_dataset._metadata_df['split'], return_counts=True))\n", "\n", - "#full_dataset._input_array" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# pd.read_csv(os.path.join('data/camelyon17_v1.0/metadata.csv'), index_col=0, dtype={'patient': 'str'})" + "#full_dataset._input_array\n", + "\n", + "#full_dataset_encode._seq_bp['chr11'].shape\n", + "full_dataset_encode._dnase_allcelltypes['HCT116']['chr11'].shape" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 68, "metadata": {}, "outputs": [], "source": [ "full_dataset.metadata_fields\n", - "config = config_encode\n", + "config = config_camelyon\n", "#config_encode.groupby_fields\n", "\n", "train_grouper = CombinatorialGrouper(\n", @@ -839,86 +1216,51 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 118, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[]" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "config_encode.eval_splits" + "# full_dataset = copy.deepcopy(full_dataset_encode)\n", + "full_dataset = copy.deepcopy(full_dataset_camelyon17)\n", + "# full_dataset_camelyon17.split_dict" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "markdown", "metadata": {}, "source": [ - "# Train/eval" + "# Initialize algorithm" ] }, { "cell_type": "code", - "execution_count": 15, - "metadata": {}, + "execution_count": 120, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Train data...\n", - " celltype = H1-hESC: n = 507309\n", - " celltype = HCT116: n = 506458\n", - " celltype = HeLa-S3: n = 509974\n", - " celltype = HepG2: n = 503007\n", - " celltype = K562: n = 506847\n", - " celltype = A549: n = 0\n", - " celltype = GM12878: n = 0\n", - "Validation (ID) data...\n", - " celltype = H1-hESC: n = 433473\n", - " celltype = HCT116: n = 431398\n", - " celltype = HeLa-S3: n = 435455\n", - " celltype = HepG2: n = 433039\n", - " celltype = K562: n = 430163\n", - " celltype = A549: n = 0\n", - " celltype = GM12878: n = 0\n", - "Test data...\n", - " celltype = H1-hESC: n = 0\n", - " celltype = HCT116: n = 0\n", - " celltype = HeLa-S3: n = 0\n", - " celltype = HepG2: n = 0\n", - " celltype = K562: n = 0\n", - " celltype = A549: n = 0\n", - " celltype = GM12878: n = 437124\n", - "Validation (OOD) data...\n", - " celltype = H1-hESC: n = 0\n", - " celltype = HCT116: n = 0\n", - " celltype = HeLa-S3: n = 0\n", - " celltype = HepG2: n = 0\n", - " celltype = K562: n = 0\n", - " celltype = A549: n = 433986\n", - " celltype = GM12878: n = 0\n" - ] - }, { "ename": "ValueError", - "evalue": "Model not recognized.", + "evalue": "I/O operation on closed file", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;31m## Initialize algorithm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 62\u001b[0;31m algorithm = initialize_algorithm(\n\u001b[0m\u001b[1;32m 63\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdatasets\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/examples/algorithms/initializer.py\u001b[0m in \u001b[0;36minitialize_algorithm\u001b[0;34m(config, datasets, train_grouper)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m'ERM'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m algorithm = ERM(\n\u001b[0m\u001b[1;32m 33\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0md_out\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0md_out\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/examples/algorithms/ERM.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, config, d_out, grouper, loss, metric, n_train_steps)\u001b[0m\n\u001b[1;32m 6\u001b[0m def __init__(self, config, d_out, grouper, loss,\n\u001b[1;32m 7\u001b[0m metric, n_train_steps):\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minitialize_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0md_out\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0;31m# initialize module\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m super().__init__(\n", - "\u001b[0;32m~/wilds/examples/models/initializer.py\u001b[0m in \u001b[0;36minitialize_model\u001b[0;34m(config, d_out)\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mGINVirtual\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnum_tasks\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0md_out\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 30\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Model not recognized.'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 31\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mValueError\u001b[0m: Model not recognized." + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0mlog_grouper\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_grouper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 59\u001b[0;31m \u001b[0mlog_group_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdatasets\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_grouper\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 60\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;31m## Initialize algorithm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/examples/utils.py\u001b[0m in \u001b[0;36mlog_group_data\u001b[0;34m(datasets, grouper, logger)\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[0mname\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'name'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 69\u001b[0m \u001b[0mdataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'dataset'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 70\u001b[0;31m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'{name} data...\\n'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 71\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mgrouper\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf' n = {len(dataset)}\\n'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/examples/utils.py\u001b[0m in \u001b[0;36mwrite\u001b[0;34m(self, msg)\u001b[0m\n\u001b[1;32m 97\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 98\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 99\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconsole\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 100\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfile\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfile\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/ipykernel/iostream.py\u001b[0m in \u001b[0;36mwrite\u001b[0;34m(self, string)\u001b[0m\n\u001b[1;32m 392\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 393\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpub_thread\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 394\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'I/O operation on closed file'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 395\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 396\u001b[0m \u001b[0;31m# Make sure that we're handling unicode\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mValueError\u001b[0m: I/O operation on closed file" ] } ], @@ -992,26 +1334,76 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 135, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "device(type='cuda', index=0)" + ] + }, + "execution_count": 135, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "for " + "algorithm.device\n", + "# datasets['train']['loader']" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 134, "metadata": {}, - "outputs": [], - "source": [] + "outputs": [ + { + "ename": "RuntimeError", + "evalue": "CUDA out of memory. Tried to allocate 14.00 MiB (GPU 0; 11.93 GiB total capacity; 10.94 GiB already allocated; 5.06 MiB free; 11.32 GiB reserved in total by PyTorch)", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# datasets['train']['dataset'].size()\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torchvision/models/densenet.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 190\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 191\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 192\u001b[0;31m \u001b[0mfeatures\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 193\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minplace\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 194\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madaptive_avg_pool2d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 118\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torchvision/models/densenet.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, init_features)\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[0mfeatures\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0minit_features\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 110\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlayer\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 111\u001b[0;31m \u001b[0mnew_features\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlayer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 112\u001b[0m \u001b[0mfeatures\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnew_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torchvision/models/densenet.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 82\u001b[0m \u001b[0mbottleneck_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcall_checkpoint_bottleneck\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprev_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 84\u001b[0;31m \u001b[0mbottleneck_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbn_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprev_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 85\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0mnew_features\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbottleneck_output\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torchvision/models/densenet.py\u001b[0m in \u001b[0;36mbn_function\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0;31m# type: (List[Tensor]) -> Tensor\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0mconcated_features\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 41\u001b[0;31m \u001b[0mbottleneck_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconcated_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# noqa: T484\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 42\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mbottleneck_output\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 129\u001b[0m \u001b[0mused\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mnormalization\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0;32min\u001b[0m \u001b[0meval\u001b[0m \u001b[0mmode\u001b[0m \u001b[0mwhen\u001b[0m \u001b[0mbuffers\u001b[0m \u001b[0mare\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 130\u001b[0m \"\"\"\n\u001b[0;32m--> 131\u001b[0;31m return F.batch_norm(\n\u001b[0m\u001b[1;32m 132\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 133\u001b[0m \u001b[0;31m# If buffers are not to be tracked, ensure that they won't be updated\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mbatch_norm\u001b[0;34m(input, running_mean, running_var, weight, bias, training, momentum, eps)\u001b[0m\n\u001b[1;32m 2054\u001b[0m \u001b[0m_verify_batch_size\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2055\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2056\u001b[0;31m return torch.batch_norm(\n\u001b[0m\u001b[1;32m 2057\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrunning_mean\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrunning_var\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2058\u001b[0m \u001b[0mtraining\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmomentum\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meps\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackends\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcudnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menabled\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 14.00 MiB (GPU 0; 11.93 GiB total capacity; 10.94 GiB already allocated; 5.06 MiB free; 11.32 GiB reserved in total by PyTorch)" + ] + } + ], + "source": [ + "# datasets['train']['dataset'].size()\n", + "algorithm.model(x.to(algorithm.device))" + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 131, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "for batch in datasets['train']['loader']:\n", + " x, y_true, metadata = batch\n", + " break" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Train" + ] }, { "cell_type": "code", @@ -1097,7 +1489,28 @@ }, { "cell_type": "code", - "execution_count": 78, + "execution_count": 126, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 126, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "for b in full_dataset:\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": 33, "metadata": {}, "outputs": [], "source": [ @@ -1130,12 +1543,12 @@ "\n", " self.dropout = 0.3\n", " self.num_cell_types = 1\n", - " self.conv1 = nn.Conv2d(5, 300, (19, 1), stride = (1, 1), padding=(9,0))\n", - " self.conv2 = nn.Conv2d(300, 200, (11, 1), stride = (1, 1), padding = (5,0))\n", - " self.conv3 = nn.Conv2d(200, 200, (7, 1), stride = (1, 1), padding = (4,0))\n", - " self.bn1 = nn.BatchNorm2d(300)\n", - " self.bn2 = nn.BatchNorm2d(200)\n", - " self.bn3 = nn.BatchNorm2d(200)\n", + " self.conv1 = nn.Conv2d(5, 50, (19, 1), stride = (1, 1), padding=(9,0))\n", + " self.conv2 = nn.Conv2d(50, 50, (11, 1), stride = (1, 1), padding = (5,0))\n", + " self.conv3 = nn.Conv2d(50, 50, (7, 1), stride = (1, 1), padding = (4,0))\n", + " self.bn1 = nn.BatchNorm2d(50)\n", + " self.bn2 = nn.BatchNorm2d(50)\n", + " self.bn3 = nn.BatchNorm2d(50)\n", " self.maxpool1 = nn.MaxPool2d((3, 1))\n", " self.maxpool2 = nn.MaxPool2d((4, 1))\n", " self.maxpool3 = nn.MaxPool2d((4, 1))\n", @@ -1167,29 +1580,242 @@ }, { "cell_type": "code", - "execution_count": 86, + "execution_count": 124, "metadata": {}, + "outputs": [], + "source": [ + "def double_conv(in_channels, out_channels): \n", + " return nn.Sequential(\n", + " nn.Conv1d(in_channels, out_channels, 7, padding=3), \n", + " nn.BatchNorm1d(out_channels), \n", + " nn.ReLU(inplace=True),\n", + " nn.Conv1d(out_channels, out_channels, 7, padding=3), \n", + " nn.BatchNorm1d(out_channels), \n", + " nn.ReLU(inplace=True)\n", + " )\n", + "\n", + "\n", + "class UNet(nn.Module):\n", + "\n", + " def __init__(self, n_class):\n", + " super().__init__()\n", + " \n", + " self.dconv_down1 = double_conv(6, 15)\n", + " self.dconv_down2 = double_conv(15, 22)\n", + " self.dconv_down3 = double_conv(22, 33)\n", + " self.dconv_down4 = double_conv(33, 49)\n", + " self.dconv_down5 = double_conv(49, 73)\n", + " self.dconv_down6 = double_conv(73, 109)\n", + "\n", + " self.maxpool = nn.MaxPool1d(2)\n", + " self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) \n", + " \n", + " self.dconv_up5 = double_conv(73 + 109, 73)\n", + " self.dconv_up4 = double_conv(49 + 73, 49)\n", + " self.dconv_up3 = double_conv(33 + 49, 33)\n", + " self.dconv_up2 = double_conv(22 + 33, 22)\n", + " self.dconv_up1 = double_conv(15 + 22, 15)\n", + " \n", + " self.conv_last = nn.Conv2d(15, n_class, 1)\n", + " \n", + " \n", + " def forward(self, x):\n", + " conv1 = self.dconv_down1(x)\n", + " x = self.maxpool(conv1)\n", + "\n", + " conv2 = self.dconv_down2(x)\n", + " x = self.maxpool(conv2)\n", + " \n", + " conv3 = self.dconv_down3(x)\n", + " x = self.maxpool(conv3)\n", + " \n", + " conv4 = self.dconv_down4(x)\n", + " x = self.maxpool(conv4)\n", + " \n", + " conv5 = self.dconv_down5(x)\n", + " x = self.maxpool(conv5)\n", + " \n", + " x = self.dconv_down6(x)\n", + " \n", + " x = self.upsample(x) \n", + " x = torch.cat([x, conv5], dim=1)\n", + " \n", + " x = self.dconv_up5(x)\n", + " x = self.upsample(x) \n", + " x = torch.cat([x, conv4], dim=1)\n", + " \n", + " x = self.dconv_up4(x)\n", + " x = self.upsample(x) \n", + " x = torch.cat([x, conv3], dim=1)\n", + " \n", + " x = self.dconv_up3(x)\n", + " x = self.upsample(x) \n", + " x = torch.cat([x, conv2], dim=1) \n", + "\n", + " x = self.dconv_up2(x)\n", + " x = self.upsample(x) \n", + " x = torch.cat([x, conv1], dim=1) \n", + " \n", + " x = self.dconv_up1(x)\n", + " \n", + " out = self.conv_last(x)\n", + " \n", + " return out" + ] + }, + { + "cell_type": "code", + "execution_count": 125, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, "outputs": [ { "data": { "text/plain": [ - "[('nnet.0.weight', 33280),\n", - " ('nnet.0.bias', 320),\n", - " ('bdlstm.0.weight_ih_l0', 409600),\n", - " ('bdlstm.0.weight_hh_l0', 409600),\n", - " ('bdlstm.0.bias_ih_l0', 1280),\n", - " ('bdlstm.0.bias_hh_l0', 1280),\n", - " ('bdlstm.0.weight_ih_l0_reverse', 409600),\n", - " ('bdlstm.0.weight_hh_l0_reverse', 409600),\n", - " ('bdlstm.0.bias_ih_l0_reverse', 1280),\n", - " ('bdlstm.0.bias_hh_l0_reverse', 1280),\n", - " ('classifier.1.weight', 592000),\n", - " ('classifier.1.bias', 925),\n", - " ('classifier.3.weight', 4625),\n", - " ('classifier.3.bias', 5)]" + "[('dconv_down1.0.weight', 630),\n", + " ('dconv_down1.0.bias', 15),\n", + " ('dconv_down1.1.weight', 15),\n", + " ('dconv_down1.1.bias', 15),\n", + " ('dconv_down1.3.weight', 1575),\n", + " ('dconv_down1.3.bias', 15),\n", + " ('dconv_down1.4.weight', 15),\n", + " ('dconv_down1.4.bias', 15),\n", + " ('dconv_down2.0.weight', 2310),\n", + " ('dconv_down2.0.bias', 22),\n", + " ('dconv_down2.1.weight', 22),\n", + " ('dconv_down2.1.bias', 22),\n", + " ('dconv_down2.3.weight', 3388),\n", + " ('dconv_down2.3.bias', 22),\n", + " ('dconv_down2.4.weight', 22),\n", + " ('dconv_down2.4.bias', 22),\n", + " ('dconv_down3.0.weight', 5082),\n", + " ('dconv_down3.0.bias', 33),\n", + " ('dconv_down3.1.weight', 33),\n", + " ('dconv_down3.1.bias', 33),\n", + " ('dconv_down3.3.weight', 7623),\n", + " ('dconv_down3.3.bias', 33),\n", + " ('dconv_down3.4.weight', 33),\n", + " ('dconv_down3.4.bias', 33),\n", + " ('dconv_down4.0.weight', 11319),\n", + " ('dconv_down4.0.bias', 49),\n", + " ('dconv_down4.1.weight', 49),\n", + " ('dconv_down4.1.bias', 49),\n", + " ('dconv_down4.3.weight', 16807),\n", + " ('dconv_down4.3.bias', 49),\n", + " ('dconv_down4.4.weight', 49),\n", + " ('dconv_down4.4.bias', 49),\n", + " ('dconv_down5.0.weight', 25039),\n", + " ('dconv_down5.0.bias', 73),\n", + " ('dconv_down5.1.weight', 73),\n", + " ('dconv_down5.1.bias', 73),\n", + " ('dconv_down5.3.weight', 37303),\n", + " ('dconv_down5.3.bias', 73),\n", + " ('dconv_down5.4.weight', 73),\n", + " ('dconv_down5.4.bias', 73),\n", + " ('dconv_down6.0.weight', 55699),\n", + " ('dconv_down6.0.bias', 109),\n", + " ('dconv_down6.1.weight', 109),\n", + " ('dconv_down6.1.bias', 109),\n", + " ('dconv_down6.3.weight', 83167),\n", + " ('dconv_down6.3.bias', 109),\n", + " ('dconv_down6.4.weight', 109),\n", + " ('dconv_down6.4.bias', 109),\n", + " ('dconv_up5.0.weight', 93002),\n", + " ('dconv_up5.0.bias', 73),\n", + " ('dconv_up5.1.weight', 73),\n", + " ('dconv_up5.1.bias', 73),\n", + " ('dconv_up5.3.weight', 37303),\n", + " ('dconv_up5.3.bias', 73),\n", + " ('dconv_up5.4.weight', 73),\n", + " ('dconv_up5.4.bias', 73),\n", + " ('dconv_up4.0.weight', 41846),\n", + " ('dconv_up4.0.bias', 49),\n", + " ('dconv_up4.1.weight', 49),\n", + " ('dconv_up4.1.bias', 49),\n", + " ('dconv_up4.3.weight', 16807),\n", + " ('dconv_up4.3.bias', 49),\n", + " ('dconv_up4.4.weight', 49),\n", + " ('dconv_up4.4.bias', 49),\n", + " ('dconv_up3.0.weight', 18942),\n", + " ('dconv_up3.0.bias', 33),\n", + " ('dconv_up3.1.weight', 33),\n", + " ('dconv_up3.1.bias', 33),\n", + " ('dconv_up3.3.weight', 7623),\n", + " ('dconv_up3.3.bias', 33),\n", + " ('dconv_up3.4.weight', 33),\n", + " ('dconv_up3.4.bias', 33),\n", + " ('dconv_up2.0.weight', 8470),\n", + " ('dconv_up2.0.bias', 22),\n", + " ('dconv_up2.1.weight', 22),\n", + " ('dconv_up2.1.bias', 22),\n", + " ('dconv_up2.3.weight', 3388),\n", + " ('dconv_up2.3.bias', 22),\n", + " ('dconv_up2.4.weight', 22),\n", + " ('dconv_up2.4.bias', 22),\n", + " ('dconv_up1.0.weight', 3885),\n", + " ('dconv_up1.0.bias', 15),\n", + " ('dconv_up1.1.weight', 15),\n", + " ('dconv_up1.1.bias', 15),\n", + " ('dconv_up1.3.weight', 1575),\n", + " ('dconv_up1.3.bias', 15),\n", + " ('dconv_up1.4.weight', 15),\n", + " ('dconv_up1.4.bias', 15),\n", + " ('conv_last.weight', 30),\n", + " ('conv_last.bias', 2)]" ] }, - "execution_count": 86, + "execution_count": 125, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model = UNet(2)\n", + "#model = DanQ(50, 5)\n", + "\n", + "lst = [(x[0], x[1].numel()) for x in model.named_parameters()]\n", + "#np.sum([x[1] for x in lst])\n", + "count_parameters(model)\n", + "lst" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[('conv1.weight', 4750),\n", + " ('conv1.bias', 50),\n", + " ('conv2.weight', 27500),\n", + " ('conv2.bias', 50),\n", + " ('conv3.weight', 17500),\n", + " ('conv3.bias', 50),\n", + " ('bn1.weight', 50),\n", + " ('bn1.bias', 50),\n", + " ('bn2.weight', 50),\n", + " ('bn2.bias', 50),\n", + " ('bn3.weight', 50),\n", + " ('bn3.bias', 50),\n", + " ('fc1.weight', 4200000),\n", + " ('fc1.bias', 1000),\n", + " ('bn4.weight', 1000),\n", + " ('bn4.bias', 1000),\n", + " ('fc2.weight', 1000000),\n", + " ('fc2.bias', 1000),\n", + " ('bn5.weight', 1000),\n", + " ('bn5.bias', 1000),\n", + " ('fc3.weight', 1000),\n", + " ('fc3.bias', 1)]" + ] + }, + "execution_count": 34, "metadata": {}, "output_type": "execute_result" } @@ -1198,14 +1824,28 @@ "def count_parameters(model):\n", " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", "\n", - "model = Beagle2()\n", - "model = DanQ(50, 5)\n", + "model = Beagle()\n", + "#model = DanQ(50, 5)\n", "\n", "lst = [(x[0], x[1].numel()) for x in model.named_parameters()]\n", "#np.sum([x[1] for x in lst])\n", "count_parameters(model)\n", "lst" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 184da7cd..08cba281 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -6,6 +6,8 @@ from wilds.common.grouper import CombinatorialGrouper from wilds.common.metrics.all_metrics import Accuracy +all_chrom_names = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] + class EncodeTFBSDataset(WILDSDataset): """ ENCODE-DREAM-wilds dataset of transcription factor binding sites. @@ -33,8 +35,8 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): self._y_size = 1 self._n_classes = 2 - # self._train_chroms = ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] - self._train_chroms = ['chr2', 'chr9', 'chr11'] + self._train_chroms = ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] + # self._train_chroms = ['chr2', 'chr9', 'chr11'] self._test_chroms = ['chr1', 'chr8', 'chr21'] self._transcription_factor = 'MAX' self._train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] @@ -44,7 +46,7 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): self._all_celltypes = self._train_celltypes + self._val_celltype + self._test_celltype self._metadata_map = {} - self._metadata_map['chr'] = self._all_chroms #['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] + self._metadata_map['chr'] = self._all_chroms self._metadata_map['celltype'] = self._all_celltypes # Get the splits @@ -75,11 +77,14 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): self._dnase_allcelltypes = {} for ct in self._all_celltypes: + """ dnase_filename = os.path.join(self._data_dir, '{}_dnase.npz'.format(ct)) dnase_npz_contents = np.load(dnase_filename) self._dnase_allcelltypes[ct] = {} for chrom in self._all_chroms: #self._seq_bp: self._dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom] + """ + self._dnase_allcelltypes[ct] = 'DNASE.{}.fc.signal.bigwig' print(ct, time.time() - itime) # Read in metadata dataframe from training+validation data @@ -90,9 +95,11 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): all_df = pd.concat([training_df, val_df]) # Filter by start/stop coordinate if needed (TODO: remove for final version) + """ filter_msk = all_df['start'] >= 0 filter_msk = all_df['start']%1000 == 0 all_df = all_df[filter_msk] + """ pd_list = [] for ct in self._all_celltypes: From a70f43b8de4cb7200d1891e0839bda6e44cd659b Mon Sep 17 00:00:00 2001 From: aikanor Date: Sun, 28 Feb 2021 07:24:05 -0800 Subject: [PATCH 025/244] model changes (TODO: eval check-in) --- examples/models/CNN_genome.py | 55 +- examples/sbox_run_expt.ipynb | 1052 +++++++++++++++------------------ 2 files changed, 503 insertions(+), 604 deletions(-) diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py index 3efdba20..f1b90d07 100644 --- a/examples/models/CNN_genome.py +++ b/examples/models/CNN_genome.py @@ -42,43 +42,46 @@ def __init__(self, n_class): def forward(self, x): - conv1 = self.dconv_down1(x) - x = self.maxpool(conv1) + # input_size = 12800 + # input_channels = 6 + conv1 = self.dconv_down1(x) # Out: (input_size) x 15 + x = self.maxpool(conv1) # (input_size / 2) x 15 - conv2 = self.dconv_down2(x) - x = self.maxpool(conv2) + conv2 = self.dconv_down2(x) # (input_size / 2) x 22 + x = self.maxpool(conv2) # (input_size / 4) x 22 - conv3 = self.dconv_down3(x) - x = self.maxpool(conv3) + conv3 = self.dconv_down3(x) # (input_size / 4) x 33 + x = self.maxpool(conv3) # (input_size / 8) x 33 - conv4 = self.dconv_down4(x) - x = self.maxpool(conv4) + conv4 = self.dconv_down4(x) # (input_size / 8) x 49 + x = self.maxpool(conv4) # (input_size / 16) x 49 - conv5 = self.dconv_down5(x) - x = self.maxpool(conv5) + conv5 = self.dconv_down5(x) # (input_size / 16) x 73 + x = self.maxpool(conv5) # (input_size / 32) x 73 - x = self.dconv_down6(x) + conv6 = self.dconv_down6(x) # (input_size / 32) x 109 + # Encoder finished. - x = self.upsample(x) - x = torch.cat([x, conv5], dim=1) + x = self.upsample(conv6) # (input_size / 16) x 109 + x = torch.cat([x, conv5], dim=1) # (input_size / 16) x (109 + 73) - x = self.dconv_up5(x) - x = self.upsample(x) - x = torch.cat([x, conv4], dim=1) + x = self.dconv_up5(x) # (input_size / 16) x 73 + x = self.upsample(x) # (input_size / 8) x 73 + x = torch.cat([x, conv4], dim=1) # (input_size / 8) x (73 + 49) - x = self.dconv_up4(x) - x = self.upsample(x) - x = torch.cat([x, conv3], dim=1) + x = self.dconv_up4(x) # (input_size / 8) x 49 + x = self.upsample(x) # (input_size / 4) x 49 + x = torch.cat([x, conv3], dim=1) # (input_size / 4) x (49 + 33) - x = self.dconv_up3(x) - x = self.upsample(x) - x = torch.cat([x, conv2], dim=1) + x = self.dconv_up3(x) # (input_size / 4) x 33 + x = self.upsample(x) # (input_size / 2) x 33 + x = torch.cat([x, conv2], dim=1) # (input_size / 2) x (33 + 22) - x = self.dconv_up2(x) - x = self.upsample(x) - x = torch.cat([x, conv1], dim=1) + x = self.dconv_up2(x) # (input_size / 2) x 22 + x = self.upsample(x) # (input_size) x 22 + x = torch.cat([x, conv1], dim=1) # (input_size) x (22 + 15) - x = self.dconv_up1(x) + x = self.dconv_up1(x) # (input_size) x 15 out = self.conv_last(x) diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb index 2c56cdd6..66712a29 100644 --- a/examples/sbox_run_expt.ipynb +++ b/examples/sbox_run_expt.ipynb @@ -34,15 +34,7 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:root:The OGB package is out of date. Your version is 1.2.4, while the latest version is 1.2.5.\n" - ] - } - ], + "outputs": [], "source": [ "import os, csv\n", "import time\n", @@ -73,7 +65,7 @@ { "data": { "text/plain": [ - "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" + "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" ] }, "execution_count": 2, @@ -192,6 +184,13 @@ "execution_count": 4, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:The OGB package is out of date. Your version is 1.2.4, while the latest version is 1.2.5.\n" + ] + }, { "name": "stdout", "output_type": "stream", @@ -321,26 +320,26 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "chr2 3.7666022777557373\n", - "chr9 5.9439966678619385\n", - "chr11 8.030796766281128\n", - "chr1 11.851332426071167\n", - "chr8 14.106642007827759\n", - "chr21 14.852506160736084\n", - "H1-hESC 14.853845119476318\n", - "HCT116 14.853914022445679\n", - "HeLa-S3 14.853951930999756\n", - "HepG2 14.853987216949463\n", - "K562 14.854026317596436\n", - "A549 14.8540620803833\n", - "GM12878 14.854098796844482\n" + "chr2 3.764267683029175\n", + "chr9 5.914910078048706\n", + "chr11 7.964999675750732\n", + "chr1 11.748822927474976\n", + "chr8 14.01279878616333\n", + "chr21 14.737261772155762\n", + "H1-hESC 14.73790693283081\n", + "HCT116 14.737961292266846\n", + "HeLa-S3 14.737993240356445\n", + "HepG2 14.738024950027466\n", + "K562 14.73805570602417\n", + "A549 14.738086223602295\n", + "GM12878 14.738116979598999\n" ] } ], @@ -420,18 +419,15 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "\"\\nfilter_msk = all_df['start'] >= 0\\nfilter_msk = all_df['start']%1000 == 0\\nall_df = all_df[filter_msk]\\n\"" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "66.32568740844727\n" + ] } ], "source": [ @@ -447,44 +443,6 @@ "print(time.time() - itime)" ] }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [], - "source": [ - "# train_regions_labeled.replace({'U': 0, 'B': 1, 'A': -1})\n", - "# a = \n", - "# np.random.choice(train_regions_labeled.shape[0], size=100000)\n", - "\n", - "v = val_regions_labeled.replace({'U': 0, 'B': 1, 'A': -1})\n", - "# seta = [full_dataset_encode.get_input(x) for x in a]\n", - "# seta[0].shape" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(array([-7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7]),\n", - " array([ 189, 854, 3579, 11535, 35901, 126629, 621676,\n", - " 7944663, 67689, 13516, 6766, 3332, 3179, 1076,\n", - " 2427]))" - ] - }, - "execution_count": 40, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.unique(v[['A549', 'GM12878', 'H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']].sum(axis=1), return_counts=True)" - ] - }, { "cell_type": "code", "execution_count": 59, @@ -547,160 +505,6 @@ "# print(time.time() - itime)" ] }, - { - "cell_type": "code", - "execution_count": 75, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
chrstartstopycelltype
2702470chr116008000H1-hESC
2702471chr116508500H1-hESC
2702472chr117009000H1-hESC
2702473chr117509500H1-hESC
2702474chr1180010000H1-hESC
..................
8843006chr81463632001463634000GM12878
8843007chr81463632501463634500GM12878
8843008chr81463633001463635000GM12878
8843009chr81463633501463635500GM12878
8843010chr81463634001463636000GM12878
\n", - "

131721055 rows × 5 columns

\n", - "
" - ], - "text/plain": [ - " chr start stop y celltype\n", - "2702470 chr11 600 800 0 H1-hESC\n", - "2702471 chr11 650 850 0 H1-hESC\n", - "2702472 chr11 700 900 0 H1-hESC\n", - "2702473 chr11 750 950 0 H1-hESC\n", - "2702474 chr11 800 1000 0 H1-hESC\n", - "... ... ... ... .. ...\n", - "8843006 chr8 146363200 146363400 0 GM12878\n", - "8843007 chr8 146363250 146363450 0 GM12878\n", - "8843008 chr8 146363300 146363500 0 GM12878\n", - "8843009 chr8 146363350 146363550 0 GM12878\n", - "8843010 chr8 146363400 146363600 0 GM12878\n", - "\n", - "[131721055 rows x 5 columns]" - ] - }, - "execution_count": 75, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "metadata_df\n", - "# tc_chr" - ] - }, { "cell_type": "code", "execution_count": 42, @@ -715,7 +519,7 @@ }, { "cell_type": "code", - "execution_count": 68, + "execution_count": 108, "metadata": {}, "outputs": [ { @@ -724,7 +528,7 @@ "686900" ] }, - "execution_count": 68, + "execution_count": 108, "metadata": {}, "output_type": "execute_result" } @@ -736,60 +540,34 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# celltype_msk.sum()\n", - "\n", - "np.unique(_metadata_df['chr'])\n", - "\n", - "# celltype_msk = (_metadata_df['celltype'] == ct)\n", - "# np.where(celltype_msk)[0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, + "execution_count": 88, "metadata": {}, - "outputs": [], - "source": [ - "_metadata_df" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true - } - }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "H1-hESC 32.5968804359436\n", - "H1-hESC 33.237690687179565\n", - "H1-hESC 37.01208806037903\n" - ] - }, - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mpos_msk\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0m_metadata_df\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'y'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mct\u001b[0m \u001b[0;32min\u001b[0m \u001b[0m_all_celltypes\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mcelltype_msk\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0m_metadata_df\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'celltype'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mct\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0mneg_ct_msk\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogical_and\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcelltype_msk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mneg_msk\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mpos_ct_msk\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogical_and\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcelltype_msk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpos_msk\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/pandas/core/ops/common.py\u001b[0m in \u001b[0;36mnew_method\u001b[0;34m(self, other)\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0mother\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mitem_from_zerodim\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mother\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 65\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mother\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 66\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mnew_method\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/pandas/core/arraylike.py\u001b[0m in \u001b[0;36m__eq__\u001b[0;34m(self, other)\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0munpack_zerodim_and_defer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"__eq__\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__eq__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mother\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 29\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cmp_method\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mother\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moperator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meq\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0munpack_zerodim_and_defer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"__ne__\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/pandas/core/series.py\u001b[0m in \u001b[0;36m_cmp_method\u001b[0;34m(self, other, op)\u001b[0m\n\u001b[1;32m 4946\u001b[0m \u001b[0mrvalues\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mextract_array\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mother\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mextract_numpy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4947\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 4948\u001b[0;31m \u001b[0mres_values\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcomparison_op\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlvalues\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrvalues\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4949\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4950\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_construct_result\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mres_values\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mres_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/pandas/core/ops/array_ops.py\u001b[0m in \u001b[0;36mcomparison_op\u001b[0;34m(left, right, op)\u001b[0m\n\u001b[1;32m 241\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 242\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mis_object_dtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlvalues\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 243\u001b[0;31m \u001b[0mres_values\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcomp_method_OBJECT_ARRAY\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mop\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlvalues\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrvalues\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 244\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 245\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/pandas/core/ops/array_ops.py\u001b[0m in \u001b[0;36mcomp_method_OBJECT_ARRAY\u001b[0;34m(op, x, y)\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlibops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvec_compare\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mravel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mravel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 55\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlibops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscalar_compare\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mravel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 56\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + "H1-hESC 8.10781979560852\n", + "H1-hESC 8.47616195678711\n", + "H1-hESC 9.822284698486328\n", + "HCT116 17.048683881759644\n", + "HCT116 17.41142964363098\n", + "HCT116 18.752415657043457\n", + "HeLa-S3 26.464386463165283\n", + "HeLa-S3 26.860748291015625\n", + "HeLa-S3 28.151614665985107\n", + "HepG2 35.439460039138794\n", + "HepG2 35.83507966995239\n", + "HepG2 37.079824924468994\n", + "K562 44.71583318710327\n", + "K562 45.092923164367676\n", + "K562 46.389798402786255\n", + "A549 53.895429372787476\n", + "A549 54.27841639518738\n", + "A549 55.64506816864014\n", + "GM12878 63.17967939376831\n", + "GM12878 63.545384883880615\n", + "GM12878 64.84915113449097\n" ] } ], @@ -801,34 +579,12 @@ "pos_msk = (_metadata_df['y'] == 1)\n", "for ct in _all_celltypes:\n", " celltype_msk = (_metadata_df['celltype'] == ct)\n", + " print(ct, time.time() - itime)\n", " neg_ct_msk = np.logical_and(celltype_msk, neg_msk)\n", " pos_ct_msk = np.logical_and(celltype_msk, pos_msk)\n", " print(ct, time.time() - itime)\n", " neg_ndces = np.where(neg_ct_msk)[0]\n", " pos_ndces = np.where(pos_ct_msk)[0]\n", - " print(ct, time.time() - itime)\n", - " np.random.seed(42)\n", - " samp_neg_ndces = np.random.choice(neg_ndces, size=len(pos_ndces), replace=False)\n", - " samp_ndces.extend(samp_neg_ndces)\n", - " samp_ndces.extend(pos_ndces)\n", - " print(ct, time.time() - itime)\n", - "_metadata_df = _metadata_df.iloc[samp_ndces, :]" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "metadata": {}, - "outputs": [], - "source": [ - "# Downsample negatives to balance each celltype\n", - "samp_ndces = []\n", - "itime = time.time()\n", - "for ct in _all_celltypes:\n", - " neg_msk = np.logical_and((_metadata_df['celltype'] == ct), (_metadata_df['y'] == 0))\n", - " pos_msk = np.logical_and((_metadata_df['celltype'] == ct), (_metadata_df['y'] == 1))\n", - " neg_ndces = np.where(neg_msk)[0]\n", - " pos_ndces = np.where(pos_msk)[0]\n", " np.random.seed(42)\n", " samp_neg_ndces = np.random.choice(neg_ndces, size=len(pos_ndces), replace=False)\n", " samp_ndces.extend(samp_neg_ndces)\n", @@ -878,7 +634,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 106, "metadata": { "jupyter": { "source_hidden": true @@ -996,13 +752,19 @@ " metadata_df['y'] = y_array\n", " self._metadata_df = metadata_df[non_ambig_mask]\n", " \n", + " # Downsample negatives to balance each celltype\n", " samp_ndces = []\n", " itime = time.time()\n", - " for ct in self._all_celltypes:\n", - " neg_msk = np.logical_and((self._metadata_df['celltype'] == ct), (self._metadata_df['y'] == 0))\n", - " pos_msk = np.logical_and((self._metadata_df['celltype'] == ct), (self._metadata_df['y'] == 1))\n", - " neg_ndces = np.where(neg_msk)[0]\n", - " pos_ndces = np.where(pos_msk)[0]\n", + " neg_msk = (self._metadata_df['y'] == 0)\n", + " pos_msk = (self._metadata_df['y'] == 1)\n", + " for ct in _all_celltypes:\n", + " celltype_msk = (self._metadata_df['celltype'] == ct)\n", + " print(ct, time.time() - itime)\n", + " neg_ct_msk = np.logical_and(celltype_msk, neg_msk)\n", + " pos_ct_msk = np.logical_and(celltype_msk, pos_msk)\n", + " print(ct, time.time() - itime)\n", + " neg_ndces = np.where(neg_ct_msk)[0]\n", + " pos_ndces = np.where(pos_ct_msk)[0]\n", " np.random.seed(42)\n", " samp_neg_ndces = np.random.choice(neg_ndces, size=len(pos_ndces), replace=False)\n", " samp_ndces.extend(samp_neg_ndces)\n", @@ -1075,38 +837,47 @@ }, { "cell_type": "code", - "execution_count": 20, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true - } - }, + "execution_count": 107, + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "chr2 3.7390823364257812\n", - "chr9 5.909312963485718\n", - "chr11 8.020122051239014\n", - "chr1 11.871179103851318\n", - "chr8 14.147786140441895\n", - "chr21 14.896430492401123\n", - "H1-hESC 21.391544818878174\n", - "HCT116 27.753155946731567\n", - "HeLa-S3 34.33590316772461\n", - "HepG2 40.81141257286072\n", - "K562 47.39495897293091\n", - "A549 54.245203495025635\n", - "GM12878 60.693068742752075\n", - "H1-hESC 16.79085922241211\n", - "HCT116 33.788668394088745\n", - "HeLa-S3 51.1968936920166\n", - "HepG2 68.32299137115479\n", - "K562 85.74746584892273\n", - "A549 103.05137896537781\n", - "GM12878 120.52022075653076\n" + "chr2 3.962329387664795\n", + "chr9 6.259538888931274\n", + "chr11 8.446826934814453\n", + "chr1 12.49940538406372\n", + "chr8 14.91869592666626\n", + "chr21 15.700694799423218\n", + "H1-hESC 23.95099449157715\n", + "HCT116 31.26502823829651\n", + "HeLa-S3 39.382277488708496\n", + "HepG2 47.24500226974487\n", + "K562 55.079211711883545\n", + "A549 62.405343532562256\n", + "GM12878 70.00356984138489\n", + "H1-hESC 8.160386562347412\n", + "H1-hESC 8.546203374862671\n", + "H1-hESC 9.868412971496582\n", + "HCT116 17.121587991714478\n", + "HCT116 17.524660110473633\n", + "HCT116 18.90956425666809\n", + "HeLa-S3 26.98938488960266\n", + "HeLa-S3 27.376858234405518\n", + "HeLa-S3 28.7989022731781\n", + "HepG2 36.29348182678223\n", + "HepG2 36.668752908706665\n", + "HepG2 38.151512145996094\n", + "K562 45.96789216995239\n", + "K562 46.33995985984802\n", + "K562 47.87280249595642\n", + "A549 55.380892276763916\n", + "A549 55.75924301147461\n", + "A549 57.22686314582825\n", + "GM12878 65.09361720085144\n", + "GM12878 65.50619888305664\n", + "GM12878 66.9196424484253\n" ] } ], @@ -1120,88 +891,29 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 118, "metadata": {}, "outputs": [], "source": [ - "a = np.random.choice(1210796, size=128)\n", - "seta = [full_dataset_encode.get_input(x) for x in a]\n", - "seta[0].shape" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(array([0, 1, 2, 3]), array([2804551, 498433, 34145, 100851]))" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.unique(full_dataset_encode._metadata_df['split'], return_counts=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([0, 0, 0, ..., 1, 1, 1]) torch.Size([60726])\n", - "(array([0, 1]), array([30426, 30300]))\n", - "(array([0, 1, 2, 3]), array([28556, 25350, 1702, 5118]))\n" - ] - } - ], - "source": [ - "full_dataset = copy.deepcopy(full_dataset_encode)\n", - "print(full_dataset._y_array, full_dataset._y_array.shape)\n", - "print(np.unique(full_dataset.y_array.numpy(), return_counts=True))\n", - "print(np.unique(full_dataset._metadata_df['split'], return_counts=True))\n", - "\n", - "#full_dataset._input_array" + "# full_dataset = copy.deepcopy(full_dataset_encode)\n", + "full_dataset = copy.deepcopy(full_dataset_camelyon17)\n", + "# full_dataset_camelyon17.split_dict" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 39, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([0, 0, 0, ..., 0, 0, 0]) torch.Size([5568233])\n", - "(array([0, 1]), array([5537933, 30300]))\n", - "(array([0, 1, 2, 3]), array([2533595, 2163528, 437124, 433986]))\n" - ] - } - ], + "outputs": [], "source": [ - "print(full_dataset._y_array, full_dataset._y_array.shape)\n", - "print(np.unique(full_dataset.y_array.numpy(), return_counts=True))\n", - "print(np.unique(full_dataset._metadata_df['split'], return_counts=True))\n", - "\n", - "#full_dataset._input_array\n", - "\n", - "#full_dataset_encode._seq_bp['chr11'].shape\n", - "full_dataset_encode._dnase_allcelltypes['HCT116']['chr11'].shape" + "a = np.random.choice(1210796, size=128)\n", + "seta = [full_dataset_encode.get_input(x) for x in a]\n", + "seta[0].shape" ] }, { "cell_type": "code", - "execution_count": 68, + "execution_count": 111, "metadata": {}, "outputs": [], "source": [ @@ -1216,22 +928,24 @@ }, { "cell_type": "code", - "execution_count": 118, + "execution_count": 104, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 104, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# full_dataset = copy.deepcopy(full_dataset_encode)\n", - "full_dataset = copy.deepcopy(full_dataset_camelyon17)\n", - "# full_dataset_camelyon17.split_dict" + "full_dataset" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "markdown", "metadata": {}, @@ -1241,26 +955,56 @@ }, { "cell_type": "code", - "execution_count": 120, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true - } - }, + "execution_count": 113, + "metadata": {}, "outputs": [ { - "ename": "ValueError", - "evalue": "I/O operation on closed file", + "name": "stdout", + "output_type": "stream", + "text": [ + "Train data...\n", + " hospital = 0: n = 53425\n", + " hospital = 1: n = 0\n", + " hospital = 2: n = 0\n", + " hospital = 3: n = 116959\n", + " hospital = 4: n = 132052\n", + "Validation (ID) data...\n", + " hospital = 0: n = 6011\n", + " hospital = 1: n = 0\n", + " hospital = 2: n = 0\n", + " hospital = 3: n = 12879\n", + " hospital = 4: n = 14670\n", + "Test data...\n", + " hospital = 0: n = 0\n", + " hospital = 1: n = 0\n", + " hospital = 2: n = 85054\n", + " hospital = 3: n = 0\n", + " hospital = 4: n = 0\n", + "Validation (OOD) data...\n", + " hospital = 0: n = 0\n", + " hospital = 1: n = 34904\n", + " hospital = 2: n = 0\n", + " hospital = 3: n = 0\n", + " hospital = 4: n = 0\n", + "Dout: 2\n" + ] + }, + { + "ename": "RuntimeError", + "evalue": "CUDA error: out of memory", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0mlog_grouper\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_grouper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 59\u001b[0;31m \u001b[0mlog_group_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdatasets\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_grouper\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 60\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;31m## Initialize algorithm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/examples/utils.py\u001b[0m in \u001b[0;36mlog_group_data\u001b[0;34m(datasets, grouper, logger)\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[0mname\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'name'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 69\u001b[0m \u001b[0mdataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'dataset'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 70\u001b[0;31m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'{name} data...\\n'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 71\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mgrouper\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf' n = {len(dataset)}\\n'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/examples/utils.py\u001b[0m in \u001b[0;36mwrite\u001b[0;34m(self, msg)\u001b[0m\n\u001b[1;32m 97\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 98\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 99\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconsole\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 100\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfile\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfile\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/ipykernel/iostream.py\u001b[0m in \u001b[0;36mwrite\u001b[0;34m(self, string)\u001b[0m\n\u001b[1;32m 392\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 393\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpub_thread\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 394\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'I/O operation on closed file'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 395\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 396\u001b[0m \u001b[0;31m# Make sure that we're handling unicode\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mValueError\u001b[0m: I/O operation on closed file" + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;31m## Initialize algorithm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 62\u001b[0;31m algorithm = initialize_algorithm(\n\u001b[0m\u001b[1;32m 63\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdatasets\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/examples/algorithms/initializer.py\u001b[0m in \u001b[0;36minitialize_algorithm\u001b[0;34m(config, datasets, train_grouper)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m'ERM'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m algorithm = ERM(\n\u001b[0m\u001b[1;32m 33\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0md_out\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0md_out\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/examples/algorithms/ERM.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, config, d_out, grouper, loss, metric, n_train_steps)\u001b[0m\n\u001b[1;32m 6\u001b[0m def __init__(self, config, d_out, grouper, loss,\n\u001b[1;32m 7\u001b[0m metric, n_train_steps):\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minitialize_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0md_out\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0;31m# initialize module\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m super().__init__(\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mto\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 610\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 611\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 612\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 613\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 614\u001b[0m def register_backward_hook(\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m 357\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 358\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 359\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 360\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m 357\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 358\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 359\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 360\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m 379\u001b[0m \u001b[0;31m# `with torch.no_grad():`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 380\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 381\u001b[0;31m \u001b[0mparam_applied\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 382\u001b[0m \u001b[0mshould_use_set_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparam_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 383\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mshould_use_set_data\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mconvert\u001b[0;34m(t)\u001b[0m\n\u001b[1;32m 608\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconvert_to_format\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 609\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmemory_format\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconvert_to_format\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 610\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 611\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 612\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: CUDA error: out of memory" ] } ], @@ -1334,29 +1078,203 @@ }, { "cell_type": "code", - "execution_count": 135, + "execution_count": 91, "metadata": {}, "outputs": [ { "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
chrstartstopycelltypesplit
3831225chr11917992501917994500H1-hESC1
4190052chr12097406002097408000H1-hESC1
7241915chr866306500663067000H1-hESC1
21449377chr238487450384876500H1-hESC0
45876013chr9569770056979000H1-hESC0
.....................
8841297chr81462777501462779501GM128782
8841298chr81462778001462780001GM128782
8841299chr81462778501462780501GM128782
8841300chr81462779001462781001GM128782
8841301chr81462779501462781501GM128782
\n", + "

1210796 rows × 6 columns

\n", + "
" + ], "text/plain": [ - "device(type='cuda', index=0)" + " chr start stop y celltype split\n", + "3831225 chr1 191799250 191799450 0 H1-hESC 1\n", + "4190052 chr1 209740600 209740800 0 H1-hESC 1\n", + "7241915 chr8 66306500 66306700 0 H1-hESC 1\n", + "21449377 chr2 38487450 38487650 0 H1-hESC 0\n", + "45876013 chr9 5697700 5697900 0 H1-hESC 0\n", + "... ... ... ... .. ... ...\n", + "8841297 chr8 146277750 146277950 1 GM12878 2\n", + "8841298 chr8 146277800 146278000 1 GM12878 2\n", + "8841299 chr8 146277850 146278050 1 GM12878 2\n", + "8841300 chr8 146277900 146278100 1 GM12878 2\n", + "8841301 chr8 146277950 146278150 1 GM12878 2\n", + "\n", + "[1210796 rows x 6 columns]" ] }, - "execution_count": 135, + "execution_count": 91, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "algorithm.device\n", + "# algorithm.device\n", + "_metadata_df\n", "# datasets['train']['loader']" ] }, { "cell_type": "code", - "execution_count": 134, + "execution_count": 90, "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'datasets' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'loader'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetadata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name 'datasets' is not defined" + ] + } + ], + "source": [ + "for batch in datasets['train']['loader']:\n", + " x, y_true, metadata = batch\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": 134, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, "outputs": [ { "ename": "RuntimeError", @@ -1387,17 +1305,6 @@ "algorithm.model(x.to(algorithm.device))" ] }, - { - "cell_type": "code", - "execution_count": 131, - "metadata": {}, - "outputs": [], - "source": [ - "for batch in datasets['train']['loader']:\n", - " x, y_true, metadata = batch\n", - " break" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -1580,13 +1487,13 @@ }, { "cell_type": "code", - "execution_count": 124, + "execution_count": 100, "metadata": {}, "outputs": [], "source": [ "def double_conv(in_channels, out_channels): \n", " return nn.Sequential(\n", - " nn.Conv1d(in_channels, out_channels, 7, padding=3), \n", + " nn.Conv1d(in_channels, out_channels, 7, padding=2), \n", " nn.BatchNorm1d(out_channels), \n", " nn.ReLU(inplace=True),\n", " nn.Conv1d(out_channels, out_channels, 7, padding=3), \n", @@ -1665,110 +1572,16 @@ }, { "cell_type": "code", - "execution_count": 125, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true - } - }, + "execution_count": 101, + "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[('dconv_down1.0.weight', 630),\n", - " ('dconv_down1.0.bias', 15),\n", - " ('dconv_down1.1.weight', 15),\n", - " ('dconv_down1.1.bias', 15),\n", - " ('dconv_down1.3.weight', 1575),\n", - " ('dconv_down1.3.bias', 15),\n", - " ('dconv_down1.4.weight', 15),\n", - " ('dconv_down1.4.bias', 15),\n", - " ('dconv_down2.0.weight', 2310),\n", - " ('dconv_down2.0.bias', 22),\n", - " ('dconv_down2.1.weight', 22),\n", - " ('dconv_down2.1.bias', 22),\n", - " ('dconv_down2.3.weight', 3388),\n", - " ('dconv_down2.3.bias', 22),\n", - " ('dconv_down2.4.weight', 22),\n", - " ('dconv_down2.4.bias', 22),\n", - " ('dconv_down3.0.weight', 5082),\n", - " ('dconv_down3.0.bias', 33),\n", - " ('dconv_down3.1.weight', 33),\n", - " ('dconv_down3.1.bias', 33),\n", - " ('dconv_down3.3.weight', 7623),\n", - " ('dconv_down3.3.bias', 33),\n", - " ('dconv_down3.4.weight', 33),\n", - " ('dconv_down3.4.bias', 33),\n", - " ('dconv_down4.0.weight', 11319),\n", - " ('dconv_down4.0.bias', 49),\n", - " ('dconv_down4.1.weight', 49),\n", - " ('dconv_down4.1.bias', 49),\n", - " ('dconv_down4.3.weight', 16807),\n", - " ('dconv_down4.3.bias', 49),\n", - " ('dconv_down4.4.weight', 49),\n", - " ('dconv_down4.4.bias', 49),\n", - " ('dconv_down5.0.weight', 25039),\n", - " ('dconv_down5.0.bias', 73),\n", - " ('dconv_down5.1.weight', 73),\n", - " ('dconv_down5.1.bias', 73),\n", - " ('dconv_down5.3.weight', 37303),\n", - " ('dconv_down5.3.bias', 73),\n", - " ('dconv_down5.4.weight', 73),\n", - " ('dconv_down5.4.bias', 73),\n", - " ('dconv_down6.0.weight', 55699),\n", - " ('dconv_down6.0.bias', 109),\n", - " ('dconv_down6.1.weight', 109),\n", - " ('dconv_down6.1.bias', 109),\n", - " ('dconv_down6.3.weight', 83167),\n", - " ('dconv_down6.3.bias', 109),\n", - " ('dconv_down6.4.weight', 109),\n", - " ('dconv_down6.4.bias', 109),\n", - " ('dconv_up5.0.weight', 93002),\n", - " ('dconv_up5.0.bias', 73),\n", - " ('dconv_up5.1.weight', 73),\n", - " ('dconv_up5.1.bias', 73),\n", - " ('dconv_up5.3.weight', 37303),\n", - " ('dconv_up5.3.bias', 73),\n", - " ('dconv_up5.4.weight', 73),\n", - " ('dconv_up5.4.bias', 73),\n", - " ('dconv_up4.0.weight', 41846),\n", - " ('dconv_up4.0.bias', 49),\n", - " ('dconv_up4.1.weight', 49),\n", - " ('dconv_up4.1.bias', 49),\n", - " ('dconv_up4.3.weight', 16807),\n", - " ('dconv_up4.3.bias', 49),\n", - " ('dconv_up4.4.weight', 49),\n", - " ('dconv_up4.4.bias', 49),\n", - " ('dconv_up3.0.weight', 18942),\n", - " ('dconv_up3.0.bias', 33),\n", - " ('dconv_up3.1.weight', 33),\n", - " ('dconv_up3.1.bias', 33),\n", - " ('dconv_up3.3.weight', 7623),\n", - " ('dconv_up3.3.bias', 33),\n", - " ('dconv_up3.4.weight', 33),\n", - " ('dconv_up3.4.bias', 33),\n", - " ('dconv_up2.0.weight', 8470),\n", - " ('dconv_up2.0.bias', 22),\n", - " ('dconv_up2.1.weight', 22),\n", - " ('dconv_up2.1.bias', 22),\n", - " ('dconv_up2.3.weight', 3388),\n", - " ('dconv_up2.3.bias', 22),\n", - " ('dconv_up2.4.weight', 22),\n", - " ('dconv_up2.4.bias', 22),\n", - " ('dconv_up1.0.weight', 3885),\n", - " ('dconv_up1.0.bias', 15),\n", - " ('dconv_up1.1.weight', 15),\n", - " ('dconv_up1.1.bias', 15),\n", - " ('dconv_up1.3.weight', 1575),\n", - " ('dconv_up1.3.bias', 15),\n", - " ('dconv_up1.4.weight', 15),\n", - " ('dconv_up1.4.bias', 15),\n", - " ('conv_last.weight', 30),\n", - " ('conv_last.bias', 2)]" + "485773" ] }, - "execution_count": 125, + "execution_count": 101, "metadata": {}, "output_type": "execute_result" } @@ -1779,58 +1592,141 @@ "\n", "lst = [(x[0], x[1].numel()) for x in model.named_parameters()]\n", "#np.sum([x[1] for x in lst])\n", - "count_parameters(model)\n", - "lst" + "count_parameters(model)" ] }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 102, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[('conv1.weight', 4750),\n", - " ('conv1.bias', 50),\n", - " ('conv2.weight', 27500),\n", - " ('conv2.bias', 50),\n", - " ('conv3.weight', 17500),\n", - " ('conv3.bias', 50),\n", - " ('bn1.weight', 50),\n", - " ('bn1.bias', 50),\n", - " ('bn2.weight', 50),\n", - " ('bn2.bias', 50),\n", - " ('bn3.weight', 50),\n", - " ('bn3.bias', 50),\n", - " ('fc1.weight', 4200000),\n", - " ('fc1.bias', 1000),\n", - " ('bn4.weight', 1000),\n", - " ('bn4.bias', 1000),\n", - " ('fc2.weight', 1000000),\n", - " ('fc2.bias', 1000),\n", - " ('bn5.weight', 1000),\n", - " ('bn5.bias', 1000),\n", - " ('fc3.weight', 1000),\n", - " ('fc3.bias', 1)]" + "UNet(\n", + " (dconv_down1): Sequential(\n", + " (0): Conv1d(6, 15, kernel_size=(7,), stride=(1,), padding=(2,))\n", + " (1): BatchNorm1d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv1d(15, 15, kernel_size=(7,), stride=(1,), padding=(3,))\n", + " (4): BatchNorm1d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " (dconv_down2): Sequential(\n", + " (0): Conv1d(15, 22, kernel_size=(7,), stride=(1,), padding=(2,))\n", + " (1): BatchNorm1d(22, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv1d(22, 22, kernel_size=(7,), stride=(1,), padding=(3,))\n", + " (4): BatchNorm1d(22, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " (dconv_down3): Sequential(\n", + " (0): Conv1d(22, 33, kernel_size=(7,), stride=(1,), padding=(2,))\n", + " (1): BatchNorm1d(33, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv1d(33, 33, kernel_size=(7,), stride=(1,), padding=(3,))\n", + " (4): BatchNorm1d(33, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " (dconv_down4): Sequential(\n", + " (0): Conv1d(33, 49, kernel_size=(7,), stride=(1,), padding=(2,))\n", + " (1): BatchNorm1d(49, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv1d(49, 49, kernel_size=(7,), stride=(1,), padding=(3,))\n", + " (4): BatchNorm1d(49, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " (dconv_down5): Sequential(\n", + " (0): Conv1d(49, 73, kernel_size=(7,), stride=(1,), padding=(2,))\n", + " (1): BatchNorm1d(73, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv1d(73, 73, kernel_size=(7,), stride=(1,), padding=(3,))\n", + " (4): BatchNorm1d(73, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " (dconv_down6): Sequential(\n", + " (0): Conv1d(73, 109, kernel_size=(7,), stride=(1,), padding=(2,))\n", + " (1): BatchNorm1d(109, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv1d(109, 109, kernel_size=(7,), stride=(1,), padding=(3,))\n", + " (4): BatchNorm1d(109, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " (maxpool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + " (upsample): Upsample(scale_factor=2.0, mode=bilinear)\n", + " (dconv_up5): Sequential(\n", + " (0): Conv1d(182, 73, kernel_size=(7,), stride=(1,), padding=(2,))\n", + " (1): BatchNorm1d(73, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv1d(73, 73, kernel_size=(7,), stride=(1,), padding=(3,))\n", + " (4): BatchNorm1d(73, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " (dconv_up4): Sequential(\n", + " (0): Conv1d(122, 49, kernel_size=(7,), stride=(1,), padding=(2,))\n", + " (1): BatchNorm1d(49, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv1d(49, 49, kernel_size=(7,), stride=(1,), padding=(3,))\n", + " (4): BatchNorm1d(49, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " (dconv_up3): Sequential(\n", + " (0): Conv1d(82, 33, kernel_size=(7,), stride=(1,), padding=(2,))\n", + " (1): BatchNorm1d(33, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv1d(33, 33, kernel_size=(7,), stride=(1,), padding=(3,))\n", + " (4): BatchNorm1d(33, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " (dconv_up2): Sequential(\n", + " (0): Conv1d(55, 22, kernel_size=(7,), stride=(1,), padding=(2,))\n", + " (1): BatchNorm1d(22, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv1d(22, 22, kernel_size=(7,), stride=(1,), padding=(3,))\n", + " (4): BatchNorm1d(22, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " (dconv_up1): Sequential(\n", + " (0): Conv1d(37, 15, kernel_size=(7,), stride=(1,), padding=(2,))\n", + " (1): BatchNorm1d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv1d(15, 15, kernel_size=(7,), stride=(1,), padding=(3,))\n", + " (4): BatchNorm1d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " (conv_last): Conv2d(15, 2, kernel_size=(1, 1), stride=(1, 1))\n", + ")" ] }, - "execution_count": 34, + "execution_count": 102, "metadata": {}, "output_type": "execute_result" } ], + "source": [ + "model" + ] + }, + { + "cell_type": "code", + "execution_count": 96, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'Beagle' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mp\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequires_grad\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mBeagle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;31m#model = DanQ(50, 5)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name 'Beagle' is not defined" + ] + } + ], "source": [ "def count_parameters(model):\n", - " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", - "\n", - "model = Beagle()\n", - "#model = DanQ(50, 5)\n", - "\n", - "lst = [(x[0], x[1].numel()) for x in model.named_parameters()]\n", - "#np.sum([x[1] for x in lst])\n", - "count_parameters(model)\n", - "lst" + " return sum(p.numel() for p in model.parameters() if p.requires_grad)" ] }, { From f006f13e47b1c6167c475fb3669dd3ff1faf476f Mon Sep 17 00:00:00 2001 From: aikanor Date: Mon, 1 Mar 2021 08:08:05 -0800 Subject: [PATCH 026/244] check in training/eval --- examples/configs/supported.py | 3 +- examples/models/CNN_genome.py | 18 +- examples/sbox_run_expt.ipynb | 1733 +++++++++++++++++++------- wilds/datasets/encodetfbs_dataset.py | 77 +- 4 files changed, 1345 insertions(+), 486 deletions(-) diff --git a/examples/configs/supported.py b/examples/configs/supported.py index fd4d6a63..bdf73267 100644 --- a/examples/configs/supported.py +++ b/examples/configs/supported.py @@ -48,8 +48,7 @@ # see initialize_*() functions for correspondence transforms = ['bert', 'image_base', 'image_resize_and_center_crop', 'poverty_train'] -models = ['resnet18_ms', 'resnet50', 'resnet34', 'wideresnet50', 'densenet121', 'bert-base-uncased', 'gin-virtual', - 'logistic_regression', 'beagle'] +models = ['resnet18_ms', 'resnet50', 'resnet34', 'wideresnet50', 'densenet121', 'bert-base-uncased', 'gin-virtual', 'logistic_regression', 'leopard'] algorithms = ['ERM', 'groupDRO', 'deepCORAL', 'IRM'] optimizers = ['SGD', 'Adam', 'AdamW'] schedulers = ['linear_schedule_with_warmup', 'ReduceLROnPlateau', 'StepLR'] diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py index f1b90d07..147f8c9e 100644 --- a/examples/models/CNN_genome.py +++ b/examples/models/CNN_genome.py @@ -6,7 +6,14 @@ -def double_conv(in_channels, out_channels): +def single_conv(in_channels, out_channels): + return nn.Sequential( + nn.Conv1d(in_channels, out_channels, 7, padding=3), + nn.BatchNorm1d(out_channels), + nn.ReLU(inplace=True) + ) + +def double_conv(in_channels, out_channels): return nn.Sequential( nn.Conv1d(in_channels, out_channels, 7, padding=3), nn.BatchNorm1d(out_channels), @@ -19,10 +26,10 @@ def double_conv(in_channels, out_channels): class UNet(nn.Module): - def __init__(self, n_class): + def __init__(self, n_class, n_channels_in=6): super().__init__() - self.dconv_down1 = double_conv(6, 15) + self.dconv_down1 = double_conv(n_channels_in, 15) self.dconv_down2 = double_conv(15, 22) self.dconv_down3 = double_conv(22, 33) self.dconv_down4 = double_conv(33, 49) @@ -30,7 +37,8 @@ def __init__(self, n_class): self.dconv_down6 = double_conv(73, 109) self.maxpool = nn.MaxPool1d(2) - self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + self.conv_middle = single_conv(109, 109) self.dconv_up5 = double_conv(73 + 109, 73) self.dconv_up4 = double_conv(49 + 73, 49) @@ -60,6 +68,8 @@ def forward(self, x): x = self.maxpool(conv5) # (input_size / 32) x 73 conv6 = self.dconv_down6(x) # (input_size / 32) x 109 + # conv6 = self.conv_middle(conv6) # Optional: convolution here. + # Encoder finished. x = self.upsample(conv6) # (input_size / 16) x 109 diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb index 66712a29..06440dc6 100644 --- a/examples/sbox_run_expt.ipynb +++ b/examples/sbox_run_expt.ipynb @@ -34,7 +34,33 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "163 µs ± 343 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" + ] + } + ], + "source": [ + "import pyBigWig\n", + "# %timeit bw = pyBigWig.open(\"/users/abalsubr/wilds/examples/data/encode-tfbs_v1.0/DNASE.K562.fc.signal.bigwig\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:The OGB package is out of date. Your version is 1.2.4, while the latest version is 1.2.5.\n" + ] + } + ], "source": [ "import os, csv\n", "import time\n", @@ -59,16 +85,16 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" + "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" ] }, - "execution_count": 2, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -76,7 +102,7 @@ "source": [ "''' set default hyperparams in default_hyperparams.py '''\n", "parser = argparse.ArgumentParser()\n", - "\n", + "CombinatorialGrouper\n", "# Required arguments\n", "parser.add_argument('-d', '--dataset', choices=supported.datasets, required=True)\n", "parser.add_argument('--algorithm', required=True, choices=supported.algorithms)\n", @@ -163,7 +189,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -175,27 +201,20 @@ "config_encode = parser.parse_args(argstr_encode.split())\n", "config_encode = populate_defaults(config_encode)\n", "\n", - "config = config_camelyon\n", - "# config = config_encode" + "# config = config_camelyon\n", + "config = config_encode\n" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:root:The OGB package is out of date. Your version is 1.2.4, while the latest version is 1.2.5.\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - "Dataset: camelyon17\n", + "Dataset: encode-tfbs\n", "Algorithm: ERM\n", "Root dir: data\n", "Split scheme: official\n", @@ -207,26 +226,26 @@ "Uniform over groups: False\n", "Distinct groups: None\n", "N groups per batch: 2\n", - "Batch size: 32\n", + "Batch size: 64\n", "Eval loader: standard\n", - "Model: densenet121\n", + "Model: leopard\n", "Model kwargs: {'pretrained': False}\n", - "Train transform: image_base\n", - "Eval transform: image_base\n", - "Target resolution: (224, 224)\n", + "Train transform: None\n", + "Eval transform: None\n", + "Target resolution: None\n", "Resize scale: None\n", "Max token length: None\n", "Loss function: cross_entropy\n", - "Groupby fields: ['hospital']\n", + "Groupby fields: ['celltype', 'y']\n", "Group dro step size: None\n", - "Coral penalty weight: 0.1\n", - "Irm lambda: 1.0\n", + "Coral penalty weight: None\n", + "Irm lambda: None\n", "Irm penalty anneal iters: None\n", "Algo log metric: accuracy\n", "Val metric: acc_avg\n", "Val metric decreasing: False\n", "N epochs: 5\n", - "Optimizer: SGD\n", + "Optimizer: Adam\n", "Lr: 0.001\n", "Weight decay: 0.01\n", "Max grad norm: None\n", @@ -250,7 +269,42 @@ "Use wandb: False\n", "Progress bar: False\n", "Resume: False\n", - "\n" + "\n", + "chr2 3.6633927822113037\n", + "chr3 6.7115819454193115\n", + "chr4 9.648637771606445\n", + "chr5 12.439441919326782\n", + "chr6 15.091757774353027\n", + "chr7 17.542895555496216\n", + "chr9 19.707583904266357\n", + "chr10 21.79905652999878\n", + "chr11 23.86957049369812\n", + "chr12 25.918642044067383\n", + "chr13 27.675577402114868\n", + "chr14 29.3148353099823\n", + "chr15 30.881144046783447\n", + "chr16 32.271193504333496\n", + "chr17 33.51785063743591\n", + "chr18 34.72123050689697\n", + "chr19 35.627156257629395\n", + "chr20 36.59872794151306\n", + "chr22 37.37847852706909\n", + "chrX 39.77280807495117\n", + "chr1 43.60475468635559\n", + "chr8 45.86070203781128\n", + "chr21 46.59553360939026\n" + ] + }, + { + "ename": "NameError", + "evalue": "name '_all_celltypes' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;31m# Data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m full_dataset = supported.datasets[config.dataset](\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0mroot_dir\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mroot_dir\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mdownload\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdownload\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/wilds/datasets/encodetfbs_dataset.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, root_dir, download, split_scheme)\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0;31m# Get the y values, and remove ambiguous labels by default.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 102\u001b[0m \u001b[0mpd_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 103\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mct\u001b[0m \u001b[0;32min\u001b[0m \u001b[0m_all_celltypes\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 104\u001b[0m \u001b[0mtc_chr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mall_df\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'start'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'stop'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mct\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[0mtc_chr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcolumns\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'start'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'stop'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'y'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name '_all_celltypes' is not defined" ] } ], @@ -300,15 +354,47 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "import copy\n", "full_dataset_camelyon17 = copy.deepcopy(full_dataset)\n", "\n", "# supported.datasets[config_encode.dataset]\n", - "# print(config_camelyon.train_transform, config_encode.train_transform)" + "# print(config_camelyon.train_transform, config_encode.train_transform)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'full_dataset' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mfull_dataset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mNameError\u001b[0m: name 'full_dataset' is not defined" + ] + } + ], + "source": [ + "full_dataset" ] }, { @@ -320,26 +406,26 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "chr2 3.764267683029175\n", - "chr9 5.914910078048706\n", - "chr11 7.964999675750732\n", - "chr1 11.748822927474976\n", - "chr8 14.01279878616333\n", - "chr21 14.737261772155762\n", - "H1-hESC 14.73790693283081\n", - "HCT116 14.737961292266846\n", - "HeLa-S3 14.737993240356445\n", - "HepG2 14.738024950027466\n", - "K562 14.73805570602417\n", - "A549 14.738086223602295\n", - "GM12878 14.738116979598999\n" + "chr2 3.90254282951355\n", + "chr9 6.149690628051758\n", + "chr11 8.327073097229004\n", + "chr1 12.291624546051025\n", + "chr8 14.624409675598145\n", + "chr21 15.413429021835327\n", + "H1-hESC 15.415196895599365\n", + "HCT116 15.415941953659058\n", + "HeLa-S3 15.416455030441284\n", + "HepG2 15.417592763900757\n", + "K562 15.418397426605225\n", + "A549 15.41891360282898\n", + "GM12878 15.419732332229614\n" ] } ], @@ -405,28 +491,32 @@ " print(chrom, time.time() - itime)\n", "\n", "_dnase_allcelltypes = {}\n", + "ct = 'avg'\n", + "dnase_avg_bw_path = os.path.join(_data_dir, 'Leopard_dnase/{}.bigwig'.format(ct))\n", + "_dnase_allcelltypes[ct] = pyBigWig.open(dnase_avg_bw_path)\n", "for ct in _all_celltypes:\n", " \"\"\"\n", - " dnase_filename = os.path.join(_data_dir, '{}_dnase.npz'.format(ct))\n", + " dnase_filename = os.path.join(self._data_dir, '{}_dnase.npz'.format(ct))\n", " dnase_npz_contents = np.load(dnase_filename)\n", - " _dnase_allcelltypes[ct] = {}\n", - " for chrom in _all_chroms: #_seq_bp:\n", - " _dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom]\n", + " self._dnase_allcelltypes[ct] = {}\n", + " for chrom in self._all_chroms: #self._seq_bp:\n", + " self._dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom]\n", " \"\"\"\n", - " _dnase_allcelltypes[ct] = 'DNASE.{}.fc.signal.bigwig'\n", + " dnase_bw_path = os.path.join(_data_dir, 'Leopard_dnase/{}.bigwig'.format(ct))\n", + " _dnase_allcelltypes[ct] = pyBigWig.open(dnase_bw_path)\n", " print(ct, time.time() - itime)" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "66.32568740844727\n" + "74.06488299369812\n" ] } ], @@ -445,33 +535,33 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - ":12: SettingWithCopyWarning: \n", + ":8: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame.\n", "Try using .loc[row_indexer,col_indexer] = value instead\n", "\n", "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", - " tc_chr['y'] = y_array\n" + " tc_chr['y'] = tc_chr['y'].replace({'U': 0, 'B': 1, 'A': 0.5}).values\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "11.363114833831787\n", - "21.872379302978516\n", - "32.51760506629944\n", - "42.88175559043884\n", - "53.35902285575867\n", - "63.94557332992554\n", - "74.44822382926941\n", - "92.237633228302\n" + "13.203907012939453\n", + "22.746795415878296\n", + "32.076903104782104\n", + "41.43525505065918\n", + "51.37267017364502\n", + "61.07364773750305\n", + "71.10506510734558\n", + "100.72125267982483\n" ] } ], @@ -483,12 +573,12 @@ "for ct in _all_celltypes:\n", " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", - " y_array = tc_chr['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", + " tc_chr['y'] = tc_chr['y'].replace({'U': 0, 'B': 1, 'A': 0.5}).values\n", " \n", - " # Now filter out ambiguous labels\n", - " non_ambig_mask = (y_array != -1)\n", - " tc_chr['y'] = y_array\n", - " tc_chr = tc_chr[non_ambig_mask]\n", + " # # Now filter out ambiguous labels\n", + " # non_ambig_mask = (y_array != -1)\n", + " # tc_chr['y'] = y_array\n", + " # tc_chr = tc_chr[non_ambig_mask]\n", " \n", " tc_chr.insert(len(tc_chr.columns), 'celltype', ct)\n", " pd_list.append(tc_chr)\n", @@ -497,101 +587,440 @@ "\n", "print(time.time() - itime)\n", "\n", - "# y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", - "# non_ambig_mask = (y_array != -1)\n", - "# metadata_df['y'] = y_array\n", - "# _metadata_df = metadata_df[non_ambig_mask]\n", + "_metadata_df = metadata_df\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "H1-hESC 8.80908489227295\n", + "H1-hESC 12.474784135818481\n", + "H1-hESC 15.258162498474121\n", + "HCT116 23.023517370224\n", + "HCT116 25.439095735549927\n", + "HCT116 27.21790099143982\n", + "HeLa-S3 34.81065845489502\n", + "HeLa-S3 35.2776780128479\n", + "HeLa-S3 36.60090255737305\n", + "HepG2 44.36072087287903\n", + "HepG2 44.74501991271973\n", + "HepG2 46.02569603919983\n", + "K562 53.825233697891235\n", + "K562 54.182188749313354\n", + "K562 55.44522547721863\n", + "A549 62.980581283569336\n", + "A549 63.34522008895874\n", + "A549 64.59721446037292\n", + "GM12878 72.41460752487183\n", + "GM12878 72.7955391407013\n", + "GM12878 74.05369997024536\n" + ] + } + ], + "source": [ + "# np.unique(_metadata_df['y'])\n", "\n", - "# print(time.time() - itime)" + "# Downsample negatives to balance each celltype\n", + "samp_ndces = []\n", + "itime = time.time()\n", + "neg_msk = (_metadata_df['y'] == 0)\n", + "pos_msk = (_metadata_df['y'] != 0)\n", + "for ct in _all_celltypes:\n", + " celltype_msk = (_metadata_df['celltype'] == ct)\n", + " print(ct, time.time() - itime)\n", + " neg_ct_msk = np.logical_and(celltype_msk, neg_msk)\n", + " pos_ct_msk = np.logical_and(celltype_msk, pos_msk)\n", + " print(ct, time.time() - itime)\n", + " neg_ndces = np.where(neg_ct_msk)[0]\n", + " pos_ndces = np.where(pos_ct_msk)[0]\n", + " np.random.seed(42)\n", + " samp_neg_ndces = np.random.choice(neg_ndces, size=len(pos_ndces), replace=False)\n", + " samp_ndces.extend(samp_neg_ndces)\n", + " samp_ndces.extend(pos_ndces)\n", + " print(ct, time.time() - itime)\n", + "_metadata_df = _metadata_df.iloc[samp_ndces, :]" ] }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 145, "metadata": {}, "outputs": [], "source": [ - "# window_size = 12800\n", - "# window_interval = window_size/2\n", - "# trl_mask = (train_regions_labeled['start']%window_interval == 0)\n", - "# train_regions_labeled[trl_mask]" + "def get_random_label_vec(metadata_df, output_size=128):\n", + " # Sample a positively labeled region at random\n", + " pos_mdf = metadata_df[metadata_df['y'] == 1] #.iloc[ metadata_df['chr'] == s['chr'], : ]\n", + " pos_seed_region = pos_mdf.iloc[np.random.randint(pos_mdf.shape[0])]\n", + "\n", + " # Extract regions from this chromosome in this celltype, to get a window of labels from\n", + " chr_msk = np.array(metadata_df['chr']) == pos_seed_region['chr']\n", + " ct_msk = np.array(metadata_df['celltype']) == pos_seed_region['celltype']\n", + " mdf = metadata_df[chr_msk & ct_msk]\n", + "\n", + " # Get labels\n", + " start_ndx = np.where(mdf['start'] == pos_seed_region['start'])[0][0]\n", + " y_label_vec = mdf.iloc[start_ndx:start_ndx+output_size, :]['y']" + ] + }, + { + "cell_type": "code", + "execution_count": 146, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
chrstartstopycelltype
5937924chr86008000.0HepG2
5937925chr86508500.0HepG2
5937926chr87009000.0HepG2
5937927chr87509500.0HepG2
5937928chr880010000.0HepG2
..................
8843006chr81463632001463634000.0HepG2
8843007chr81463632501463634500.0HepG2
8843008chr81463633001463635000.0HepG2
8843009chr81463633501463635500.0HepG2
8843010chr81463634001463636000.0HepG2
\n", + "

2905087 rows × 5 columns

\n", + "
" + ], + "text/plain": [ + " chr start stop y celltype\n", + "5937924 chr8 600 800 0.0 HepG2\n", + "5937925 chr8 650 850 0.0 HepG2\n", + "5937926 chr8 700 900 0.0 HepG2\n", + "5937927 chr8 750 950 0.0 HepG2\n", + "5937928 chr8 800 1000 0.0 HepG2\n", + "... ... ... ... ... ...\n", + "8843006 chr8 146363200 146363400 0.0 HepG2\n", + "8843007 chr8 146363250 146363450 0.0 HepG2\n", + "8843008 chr8 146363300 146363500 0.0 HepG2\n", + "8843009 chr8 146363350 146363550 0.0 HepG2\n", + "8843010 chr8 146363400 146363600 0.0 HepG2\n", + "\n", + "[2905087 rows x 5 columns]" + ] + }, + "execution_count": 146, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 154, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "8809571 1.0\n", + "8809572 1.0\n", + "8809573 1.0\n", + "8809574 1.0\n", + "8809575 1.0\n", + "8809576 0.5\n", + "8809577 0.5\n", + "8809578 0.5\n", + "8809579 1.0\n", + "8809580 1.0\n", + "8809581 1.0\n", + "8809582 1.0\n", + "8809583 1.0\n", + "8809584 1.0\n", + "8809585 0.5\n", + "8809586 0.5\n", + "8809587 0.0\n", + "8809588 0.0\n", + "8809589 0.0\n", + "8809590 0.0\n", + "8809591 0.0\n", + "8809592 0.0\n", + "8809593 0.0\n", + "8809594 0.0\n", + "8809595 0.0\n", + "8809596 0.0\n", + "8809597 0.0\n", + "8809598 0.0\n", + "8809599 0.0\n", + "8809600 0.0\n", + "8809601 0.0\n", + "8809602 0.0\n", + "Name: y, dtype: float64" + ] + }, + "execution_count": 154, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 107, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0.])" + ] + }, + "execution_count": 107, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.unique(mdf[:256+start_bin]['y'])" + ] + }, + { + "cell_type": "code", + "execution_count": 150, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "chr chr8\n", + "start 144691450\n", + "stop 144691650\n", + "y 1.0\n", + "celltype HepG2\n", + "Name: 8809571, dtype: object" + ] + }, + "execution_count": 150, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pos_seed_region" ] }, { "cell_type": "code", - "execution_count": 108, + "execution_count": 98, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "686900" + "array([ 600, 600, 600, ..., 135005900, 135005900,\n", + " 135005900])" ] }, - "execution_count": 108, + "execution_count": 98, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "(metadata_df['y'] == 1).sum()\n", - "# pd_list[0][non_ambig_mask]" + "# arr = metadata_df[mdf_msk]['start']\n", + "#arr == \n", + "np.sort(arr)" ] }, { "cell_type": "code", - "execution_count": 88, + "execution_count": 69, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "H1-hESC 8.10781979560852\n", - "H1-hESC 8.47616195678711\n", - "H1-hESC 9.822284698486328\n", - "HCT116 17.048683881759644\n", - "HCT116 17.41142964363098\n", - "HCT116 18.752415657043457\n", - "HeLa-S3 26.464386463165283\n", - "HeLa-S3 26.860748291015625\n", - "HeLa-S3 28.151614665985107\n", - "HepG2 35.439460039138794\n", - "HepG2 35.83507966995239\n", - "HepG2 37.079824924468994\n", - "K562 44.71583318710327\n", - "K562 45.092923164367676\n", - "K562 46.389798402786255\n", - "A549 53.895429372787476\n", - "A549 54.27841639518738\n", - "A549 55.64506816864014\n", - "GM12878 63.17967939376831\n", - "GM12878 63.545384883880615\n", - "GM12878 64.84915113449097\n" + "116.39193439483643\n" ] } ], "source": [ - "# Downsample negatives to balance each celltype\n", - "samp_ndces = []\n", "itime = time.time()\n", - "neg_msk = (_metadata_df['y'] == 0)\n", - "pos_msk = (_metadata_df['y'] == 1)\n", - "for ct in _all_celltypes:\n", - " celltype_msk = (_metadata_df['celltype'] == ct)\n", - " print(ct, time.time() - itime)\n", - " neg_ct_msk = np.logical_and(celltype_msk, neg_msk)\n", - " pos_ct_msk = np.logical_and(celltype_msk, pos_msk)\n", - " print(ct, time.time() - itime)\n", - " neg_ndces = np.where(neg_ct_msk)[0]\n", - " pos_ndces = np.where(pos_ct_msk)[0]\n", - " np.random.seed(42)\n", - " samp_neg_ndces = np.random.choice(neg_ndces, size=len(pos_ndces), replace=False)\n", - " samp_ndces.extend(samp_neg_ndces)\n", - " samp_ndces.extend(pos_ndces)\n", - " print(ct, time.time() - itime)\n", - "_metadata_df = _metadata_df.iloc[samp_ndces, :]\n", - "\n", + "lts = ['{}:{}-{}'.format(x[0], x[1], x[2]) for x in zip(metadata_df['chr'], metadata_df['start'], metadata_df['stop'])]\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "202800" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#_metadata_df['y']\n", + "# s = metadata_df.iloc[np.array(pos_msk), :]\n", + "ntry = s.iloc[5]\n", + "ntry['start'] + 12800\n", + "# s['chr'], s['start'], s['stop'] # np.unique(s['chr'], return_counts=True)\n", + "# all_df\n", + "# metadata_df" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "H1-hESC 7.871182918548584\n", + "H1-hESC 8.298529148101807\n", + "H1-hESC 9.57175898551941\n", + "HCT116 17.01794719696045\n", + "HCT116 17.36267113685608\n", + "HCT116 18.669682025909424\n", + "HeLa-S3 26.405478954315186\n", + "HeLa-S3 26.759119272232056\n", + "HeLa-S3 28.043395042419434\n", + "HepG2 35.623862981796265\n", + "HepG2 35.98245143890381\n", + "HepG2 37.29869079589844\n", + "K562 44.92080807685852\n", + "K562 45.256179332733154\n", + "K562 46.7364935874939\n", + "A549 54.39264512062073\n", + "A549 54.74424934387207\n", + "A549 56.03351712226868\n", + "GM12878 63.745240211486816\n", + "GM12878 64.1029920578003\n", + "GM12878 65.43286633491516\n" + ] + } + ], + "source": [ "train_regions_mask = np.isin(_metadata_df['chr'], _train_chroms)\n", "val_regions_mask = np.isin(_metadata_df['chr'], _test_chroms)\n", "train_celltype_mask = np.isin(_metadata_df['celltype'], _train_celltypes)\n", @@ -629,17 +1058,13 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Initialize dataset object" + "# Dataset object (long version)" ] }, { "cell_type": "code", - "execution_count": 106, - "metadata": { - "jupyter": { - "source_hidden": true - } - }, + "execution_count": 3, + "metadata": {}, "outputs": [], "source": [ "import os, time\n", @@ -719,11 +1144,14 @@ " \n", " self._dnase_allcelltypes = {}\n", " for ct in self._all_celltypes:\n", + " \"\"\"\n", " dnase_filename = os.path.join(self._data_dir, '{}_dnase.npz'.format(ct))\n", " dnase_npz_contents = np.load(dnase_filename)\n", " self._dnase_allcelltypes[ct] = {}\n", " for chrom in self._all_chroms: #self._seq_bp:\n", " self._dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom]\n", + " \"\"\"\n", + " self._dnase_allcelltypes[ct] = os.path.join(self._data_dir, 'DNASE.{}.fc.signal.bigwig'.format(ct))\n", " print(ct, time.time() - itime)\n", " \n", " # Read in metadata dataframe from training+validation data\n", @@ -733,24 +1161,22 @@ " val_df = val_regions_labeled[np.isin(val_regions_labeled['chr'], self._test_chroms)]\n", " all_df = pd.concat([training_df, val_df])\n", " \n", - " # Filter by start/stop coordinate if needed (TODO: remove for final version)\n", - " # filter_msk = all_df['start'] >= 0\n", - " # filter_msk = all_df['start']%1000 == 0\n", - " # all_df = all_df[filter_msk]\n", - " \n", + " # Get the y values, and remove ambiguous labels by default.\n", " pd_list = []\n", - " for ct in self._all_celltypes:\n", + " for ct in _all_celltypes:\n", " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", + " y_array = tc_chr['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", + "\n", + " # Now filter out ambiguous labels\n", + " non_ambig_mask = (y_array != -1)\n", + " tc_chr['y'] = y_array\n", + " tc_chr = tc_chr[non_ambig_mask]\n", + "\n", " tc_chr.insert(len(tc_chr.columns), 'celltype', ct)\n", " pd_list.append(tc_chr)\n", - " metadata_df = pd.concat(pd_list)\n", - " \n", - " # Get the y values, and remove ambiguous labels by default.\n", - " y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", - " non_ambig_mask = (y_array != -1)\n", - " metadata_df['y'] = y_array\n", - " self._metadata_df = metadata_df[non_ambig_mask]\n", + " print(time.time() - itime)\n", + " self._metadata_df = pd.concat(pd_list)\n", " \n", " # Downsample negatives to balance each celltype\n", " samp_ndces = []\n", @@ -814,20 +1240,30 @@ "\n", " def get_input(self, idx):\n", " \"\"\"\n", - " Returns x for a given idx.\n", + " Returns x for a given idx in metadata_array, which has been filtered to only take windows with the desired stride.\n", " Computes this from: \n", " (1) sequence features in self._seq_bp\n", - " (2) DNase features in self._dnase_allcelltypes\n", + " (2) DNase bigwig file paths in self._dnase_allcelltypes\n", " (3) Metadata for the index (location along the genome with 200bp window width)\n", " \"\"\"\n", + " \n", " this_metadata = self._metadata_df.iloc[idx, :]\n", + " \"\"\"\n", " flank_size = 400\n", " interval_start = this_metadata['start'] - flank_size\n", " interval_end = this_metadata['stop'] + flank_size\n", " dnase_this = self._dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end]\n", " seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end]\n", " return torch.tensor(np.column_stack([seq_this, dnase_this]))\n", - "\n", + " \"\"\"\n", + " window_size = 12800\n", + " interval_start = this_metadata['start']\n", + " interval_end = this_metadata['stop'] + window_size\n", + " seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end]\n", + " dnase_bw = self._dnase_allcelltypes[this_metadata['celltype']]\n", + " dnase_this = dnase_bw.values(chrom, interval_start, interval_end, numpy=True)\n", + " return torch.tensor(np.column_stack([seq_this, dnase_this]))\n", + " \n", " def eval(self, y_pred, y_true, metadata):\n", " return self.standard_group_eval(\n", " self._metric,\n", @@ -838,7 +1274,12 @@ { "cell_type": "code", "execution_count": 107, - "metadata": {}, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, "outputs": [ { "name": "stdout", @@ -891,59 +1332,41 @@ }, { "cell_type": "code", - "execution_count": 118, - "metadata": {}, - "outputs": [], - "source": [ - "# full_dataset = copy.deepcopy(full_dataset_encode)\n", - "full_dataset = copy.deepcopy(full_dataset_camelyon17)\n", - "# full_dataset_camelyon17.split_dict" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "metadata": {}, - "outputs": [], - "source": [ - "a = np.random.choice(1210796, size=128)\n", - "seta = [full_dataset_encode.get_input(x) for x in a]\n", - "seta[0].shape" - ] - }, - { - "cell_type": "code", - "execution_count": 111, - "metadata": {}, - "outputs": [], - "source": [ - "full_dataset.metadata_fields\n", - "config = config_camelyon\n", - "#config_encode.groupby_fields\n", - "\n", - "train_grouper = CombinatorialGrouper(\n", - " dataset=full_dataset,\n", - " groupby_fields=config.groupby_fields)" - ] - }, - { - "cell_type": "code", - "execution_count": 104, + "execution_count": 2, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 104, - "metadata": {}, - "output_type": "execute_result" + "ename": "ModuleNotFoundError", + "evalue": "No module named 'pyBigWig'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# full_dataset_encode\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mpyBigWig\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'pyBigWig'" + ] } ], "source": [ - "full_dataset" + "# full_dataset_encode\n", + "import pyBigWig" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [], + "source": [ + "a = np.random.choice(1210796, size=128)\n", + "seta = [full_dataset_encode.get_input(x) for x in a]\n", + "seta[0].shape\n", + "\n", + "# full_dataset = copy.deepcopy(full_dataset_encode)\n", + "# full_dataset = copy.deepcopy(full_dataset_camelyon17)\n", + "# full_dataset_camelyon17.split_dict\n", + "\n", + "# full_dataset" ] }, { @@ -955,7 +1378,7 @@ }, { "cell_type": "code", - "execution_count": 113, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -988,27 +1411,16 @@ " hospital = 4: n = 0\n", "Dout: 2\n" ] - }, - { - "ename": "RuntimeError", - "evalue": "CUDA error: out of memory", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;31m## Initialize algorithm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 62\u001b[0;31m algorithm = initialize_algorithm(\n\u001b[0m\u001b[1;32m 63\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdatasets\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/examples/algorithms/initializer.py\u001b[0m in \u001b[0;36minitialize_algorithm\u001b[0;34m(config, datasets, train_grouper)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m'ERM'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m algorithm = ERM(\n\u001b[0m\u001b[1;32m 33\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0md_out\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0md_out\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/examples/algorithms/ERM.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, config, d_out, grouper, loss, metric, n_train_steps)\u001b[0m\n\u001b[1;32m 6\u001b[0m def __init__(self, config, d_out, grouper, loss,\n\u001b[1;32m 7\u001b[0m metric, n_train_steps):\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minitialize_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0md_out\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0;31m# initialize module\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m super().__init__(\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mto\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 610\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 611\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 612\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 613\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 614\u001b[0m def register_backward_hook(\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m 357\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 358\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 359\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 360\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m 357\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 358\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 359\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 360\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m 379\u001b[0m \u001b[0;31m# `with torch.no_grad():`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 380\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 381\u001b[0;31m \u001b[0mparam_applied\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 382\u001b[0m \u001b[0mshould_use_set_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparam_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 383\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mshould_use_set_data\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mconvert\u001b[0;34m(t)\u001b[0m\n\u001b[1;32m 608\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconvert_to_format\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 609\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmemory_format\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconvert_to_format\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 610\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 611\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 612\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mRuntimeError\u001b[0m: CUDA error: out of memory" - ] } ], "source": [ + "config = config_camelyon\n", + "\n", + "\n", + "train_grouper = CombinatorialGrouper(\n", + " dataset=full_dataset,\n", + " groupby_fields=config.groupby_fields)\n", + "\n", "datasets = defaultdict(dict)\n", "for split in full_dataset.split_dict.keys():\n", " if split=='train':\n", @@ -1078,188 +1490,31 @@ }, { "cell_type": "code", - "execution_count": 91, + "execution_count": 29, "metadata": {}, "outputs": [ { "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
chrstartstopycelltypesplit
3831225chr11917992501917994500H1-hESC1
4190052chr12097406002097408000H1-hESC1
7241915chr866306500663067000H1-hESC1
21449377chr238487450384876500H1-hESC0
45876013chr9569770056979000H1-hESC0
.....................
8841297chr81462777501462779501GM128782
8841298chr81462778001462780001GM128782
8841299chr81462778501462780501GM128782
8841300chr81462779001462781001GM128782
8841301chr81462779501462781501GM128782
\n", - "

1210796 rows × 6 columns

\n", - "
" - ], "text/plain": [ - " chr start stop y celltype split\n", - "3831225 chr1 191799250 191799450 0 H1-hESC 1\n", - "4190052 chr1 209740600 209740800 0 H1-hESC 1\n", - "7241915 chr8 66306500 66306700 0 H1-hESC 1\n", - "21449377 chr2 38487450 38487650 0 H1-hESC 0\n", - "45876013 chr9 5697700 5697900 0 H1-hESC 0\n", - "... ... ... ... .. ... ...\n", - "8841297 chr8 146277750 146277950 1 GM12878 2\n", - "8841298 chr8 146277800 146278000 1 GM12878 2\n", - "8841299 chr8 146277850 146278050 1 GM12878 2\n", - "8841300 chr8 146277900 146278100 1 GM12878 2\n", - "8841301 chr8 146277950 146278150 1 GM12878 2\n", - "\n", - "[1210796 rows x 6 columns]" + "" ] }, - "execution_count": 91, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# algorithm.device\n", - "_metadata_df\n", + "full_dataset\n", "# datasets['train']['loader']" ] }, { "cell_type": "code", - "execution_count": 90, + "execution_count": 15, "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'datasets' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'loader'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetadata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mNameError\u001b[0m: name 'datasets' is not defined" - ] - } - ], + "outputs": [], "source": [ "for batch in datasets['train']['loader']:\n", " x, y_true, metadata = batch\n", @@ -1268,36 +1523,70 @@ }, { "cell_type": "code", - "execution_count": 134, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,\n", + " 0, 1, 1, 1, 0, 0, 0, 0])" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" } - }, + ], + "source": [ + "y_true" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, "outputs": [ { - "ename": "RuntimeError", - "evalue": "CUDA out of memory. Tried to allocate 14.00 MiB (GPU 0; 11.93 GiB total capacity; 10.94 GiB already allocated; 5.06 MiB free; 11.32 GiB reserved in total by PyTorch)", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# datasets['train']['dataset'].size()\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torchvision/models/densenet.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 190\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 191\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 192\u001b[0;31m \u001b[0mfeatures\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 193\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minplace\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 194\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madaptive_avg_pool2d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 118\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torchvision/models/densenet.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, init_features)\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[0mfeatures\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0minit_features\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 110\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlayer\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 111\u001b[0;31m \u001b[0mnew_features\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlayer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 112\u001b[0m \u001b[0mfeatures\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnew_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torchvision/models/densenet.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 82\u001b[0m \u001b[0mbottleneck_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcall_checkpoint_bottleneck\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprev_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 84\u001b[0;31m \u001b[0mbottleneck_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbn_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprev_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 85\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0mnew_features\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbottleneck_output\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torchvision/models/densenet.py\u001b[0m in \u001b[0;36mbn_function\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0;31m# type: (List[Tensor]) -> Tensor\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0mconcated_features\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 41\u001b[0;31m \u001b[0mbottleneck_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconcated_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# noqa: T484\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 42\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mbottleneck_output\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 129\u001b[0m \u001b[0mused\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mnormalization\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0;32min\u001b[0m \u001b[0meval\u001b[0m \u001b[0mmode\u001b[0m \u001b[0mwhen\u001b[0m \u001b[0mbuffers\u001b[0m \u001b[0mare\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 130\u001b[0m \"\"\"\n\u001b[0;32m--> 131\u001b[0;31m return F.batch_norm(\n\u001b[0m\u001b[1;32m 132\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 133\u001b[0m \u001b[0;31m# If buffers are not to be tracked, ensure that they won't be updated\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mbatch_norm\u001b[0;34m(input, running_mean, running_var, weight, bias, training, momentum, eps)\u001b[0m\n\u001b[1;32m 2054\u001b[0m \u001b[0m_verify_batch_size\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2055\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2056\u001b[0;31m return torch.batch_norm(\n\u001b[0m\u001b[1;32m 2057\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrunning_mean\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrunning_var\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2058\u001b[0m \u001b[0mtraining\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmomentum\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meps\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackends\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcudnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menabled\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 14.00 MiB (GPU 0; 11.93 GiB total capacity; 10.94 GiB already allocated; 5.06 MiB free; 11.32 GiB reserved in total by PyTorch)" - ] + "data": { + "text/plain": [ + "tensor([[ 0.1406, -0.0628],\n", + " [ 0.0534, 0.0359],\n", + " [-0.0174, -0.0097],\n", + " [-0.0571, -0.2381],\n", + " [ 0.1590, -0.0559],\n", + " [ 0.1254, -0.0139],\n", + " [-0.0423, 0.0439],\n", + " [ 0.1621, 0.0730],\n", + " [ 0.0554, 0.0796],\n", + " [-0.0532, 0.0667],\n", + " [-0.1927, -0.0387],\n", + " [ 0.1352, -0.0385],\n", + " [-0.1320, 0.0140],\n", + " [-0.0531, -0.1171],\n", + " [-0.0378, -0.0134],\n", + " [ 0.1047, 0.0298],\n", + " [ 0.0355, -0.0497],\n", + " [ 0.1065, -0.0218],\n", + " [-0.1883, 0.1298],\n", + " [ 0.0699, -0.0875],\n", + " [-0.1233, 0.1793],\n", + " [ 0.0151, 0.0708],\n", + " [-0.0973, -0.0033],\n", + " [ 0.1027, -0.2456],\n", + " [ 0.0433, -0.0441],\n", + " [ 0.1013, -0.1020],\n", + " [ 0.1309, -0.0051],\n", + " [ 0.0028, -0.0558],\n", + " [ 0.0635, 0.0575],\n", + " [-0.0066, 0.0666],\n", + " [-0.0076, -0.0375],\n", + " [ 0.1336, 0.0024]], device='cuda:0', grad_fn=)" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -1387,13 +1676,6 @@ "outputs": [], "source": [] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "code", "execution_count": 126, @@ -1418,14 +1700,14 @@ { "cell_type": "code", "execution_count": 33, - "metadata": {}, + "metadata": { + "jupyter": { + "source_hidden": true + } + }, "outputs": [], "source": [ - "import math\n", - "import numpy as np\n", - "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", + "\n", "\n", "class Beagle(nn.Module):\n", " \"\"\"\n", @@ -1487,10 +1769,17 @@ }, { "cell_type": "code", - "execution_count": 100, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ + "import math\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "\n", "def double_conv(in_channels, out_channels): \n", " return nn.Sequential(\n", " nn.Conv1d(in_channels, out_channels, 7, padding=2), \n", @@ -1572,33 +1861,13 @@ }, { "cell_type": "code", - "execution_count": 101, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "485773" - ] - }, - "execution_count": 101, - "metadata": {}, - "output_type": "execute_result" + "execution_count": 20, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true } - ], - "source": [ - "model = UNet(2)\n", - "#model = DanQ(50, 5)\n", - "\n", - "lst = [(x[0], x[1].numel()) for x in model.named_parameters()]\n", - "#np.sum([x[1] for x in lst])\n", - "count_parameters(model)" - ] - }, - { - "cell_type": "code", - "execution_count": 102, - "metadata": {}, + }, "outputs": [ { "data": { @@ -1698,43 +1967,593 @@ ")" ] }, - "execution_count": 102, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "model = UNet(2)\n", "model" ] }, { "cell_type": "code", - "execution_count": 96, + "execution_count": 101, "metadata": {}, "outputs": [ { - "ename": "NameError", - "evalue": "name 'Beagle' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mp\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequires_grad\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mBeagle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;31m#model = DanQ(50, 5)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mNameError\u001b[0m: name 'Beagle' is not defined" - ] + "data": { + "text/plain": [ + "485773" + ] + }, + "execution_count": 101, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ "def count_parameters(model):\n", - " return sum(p.numel() for p in model.parameters() if p.requires_grad)" + " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", + "\n", + "lst = [(x[0], x[1].numel()) for x in model.named_parameters()]\n", + "#np.sum([x[1] for x in lst])\n", + "count_parameters(model)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "6955906" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "count_parameters(algorithm.model)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "lst = [(x[0], x[1].numel()) for x in algorithm.model.named_parameters()]" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "DenseNet(\n", + " (features): Sequential(\n", + " (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", + " (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu0): ReLU(inplace=True)\n", + " (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", + " (denseblock1): _DenseBlock(\n", + " (denselayer1): _DenseLayer(\n", + " (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer2): _DenseLayer(\n", + " (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer3): _DenseLayer(\n", + " (norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer4): _DenseLayer(\n", + " (norm1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer5): _DenseLayer(\n", + " (norm1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer6): _DenseLayer(\n", + " (norm1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " )\n", + " (transition1): _Transition(\n", + " (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", + " )\n", + " (denseblock2): _DenseBlock(\n", + " (denselayer1): _DenseLayer(\n", + " (norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer2): _DenseLayer(\n", + " (norm1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer3): _DenseLayer(\n", + " (norm1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer4): _DenseLayer(\n", + " (norm1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer5): _DenseLayer(\n", + " (norm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer6): _DenseLayer(\n", + " (norm1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer7): _DenseLayer(\n", + " (norm1): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer8): _DenseLayer(\n", + " (norm1): BatchNorm2d(352, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer9): _DenseLayer(\n", + " (norm1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer10): _DenseLayer(\n", + " (norm1): BatchNorm2d(416, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer11): _DenseLayer(\n", + " (norm1): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(448, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer12): _DenseLayer(\n", + " (norm1): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(480, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " )\n", + " (transition2): _Transition(\n", + " (norm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", + " )\n", + " (denseblock3): _DenseBlock(\n", + " (denselayer1): _DenseLayer(\n", + " (norm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer2): _DenseLayer(\n", + " (norm1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer3): _DenseLayer(\n", + " (norm1): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer4): _DenseLayer(\n", + " (norm1): BatchNorm2d(352, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer5): _DenseLayer(\n", + " (norm1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer6): _DenseLayer(\n", + " (norm1): BatchNorm2d(416, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer7): _DenseLayer(\n", + " (norm1): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(448, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer8): _DenseLayer(\n", + " (norm1): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(480, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer9): _DenseLayer(\n", + " (norm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer10): _DenseLayer(\n", + " (norm1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer11): _DenseLayer(\n", + " (norm1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer12): _DenseLayer(\n", + " (norm1): BatchNorm2d(608, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer13): _DenseLayer(\n", + " (norm1): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer14): _DenseLayer(\n", + " (norm1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer15): _DenseLayer(\n", + " (norm1): BatchNorm2d(704, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(704, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer16): _DenseLayer(\n", + " (norm1): BatchNorm2d(736, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(736, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer17): _DenseLayer(\n", + " (norm1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer18): _DenseLayer(\n", + " (norm1): BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(800, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer19): _DenseLayer(\n", + " (norm1): BatchNorm2d(832, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer20): _DenseLayer(\n", + " (norm1): BatchNorm2d(864, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(864, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer21): _DenseLayer(\n", + " (norm1): BatchNorm2d(896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer22): _DenseLayer(\n", + " (norm1): BatchNorm2d(928, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(928, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer23): _DenseLayer(\n", + " (norm1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(960, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer24): _DenseLayer(\n", + " (norm1): BatchNorm2d(992, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(992, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " )\n", + " (transition3): _Transition(\n", + " (norm): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", + " )\n", + " (denseblock4): _DenseBlock(\n", + " (denselayer1): _DenseLayer(\n", + " (norm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer2): _DenseLayer(\n", + " (norm1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer3): _DenseLayer(\n", + " (norm1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer4): _DenseLayer(\n", + " (norm1): BatchNorm2d(608, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer5): _DenseLayer(\n", + " (norm1): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer6): _DenseLayer(\n", + " (norm1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer7): _DenseLayer(\n", + " (norm1): BatchNorm2d(704, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(704, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer8): _DenseLayer(\n", + " (norm1): BatchNorm2d(736, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(736, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer9): _DenseLayer(\n", + " (norm1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer10): _DenseLayer(\n", + " (norm1): BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(800, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer11): _DenseLayer(\n", + " (norm1): BatchNorm2d(832, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer12): _DenseLayer(\n", + " (norm1): BatchNorm2d(864, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(864, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer13): _DenseLayer(\n", + " (norm1): BatchNorm2d(896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer14): _DenseLayer(\n", + " (norm1): BatchNorm2d(928, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(928, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer15): _DenseLayer(\n", + " (norm1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(960, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer16): _DenseLayer(\n", + " (norm1): BatchNorm2d(992, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(992, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " )\n", + " (norm5): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (classifier): Linear(in_features=1024, out_features=2, bias=True)\n", + ")" + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "algorithm.model" + ] }, { "cell_type": "code", diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 08cba281..04f5d08d 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -2,6 +2,7 @@ import torch import pandas as pd import numpy as np +import pyBigWig from wilds.datasets.wilds_dataset import WILDSDataset from wilds.common.grouper import CombinatorialGrouper from wilds.common.metrics.all_metrics import Accuracy @@ -76,6 +77,9 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): print(chrom, time.time() - itime) self._dnase_allcelltypes = {} + ct = 'avg' + dnase_avg_bw_path = os.path.join(self._data_dir, 'Leopard_dnase/{}.bigwig'.format(ct)) + self._dnase_allcelltypes[ct] = pyBigWig.open(dnase_avg_bw_path) for ct in self._all_celltypes: """ dnase_filename = os.path.join(self._data_dir, '{}_dnase.npz'.format(ct)) @@ -84,8 +88,8 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): for chrom in self._all_chroms: #self._seq_bp: self._dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom] """ - self._dnase_allcelltypes[ct] = 'DNASE.{}.fc.signal.bigwig' - print(ct, time.time() - itime) + dnase_bw_path = os.path.join(self._data_dir, 'Leopard_dnase/{}.bigwig'.format(ct)) + self._dnase_allcelltypes[ct] = pyBigWig.open(dnase_bw_path) # Read in metadata dataframe from training+validation data train_regions_labeled = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.train.labels.tsv.gz'.format(self._transcription_factor)), sep='\t') @@ -94,34 +98,36 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): val_df = val_regions_labeled[np.isin(val_regions_labeled['chr'], self._test_chroms)] all_df = pd.concat([training_df, val_df]) - # Filter by start/stop coordinate if needed (TODO: remove for final version) - """ - filter_msk = all_df['start'] >= 0 - filter_msk = all_df['start']%1000 == 0 - all_df = all_df[filter_msk] - """ - + # Get the y values, and remove ambiguous labels by default. pd_list = [] for ct in self._all_celltypes: tc_chr = all_df[['chr', 'start', 'stop', ct]] tc_chr.columns = ['chr', 'start', 'stop', 'y'] + y_array = tc_chr['y'].replace({'U': 0, 'B': 1, 'A': -1}).values + + # Now filter out ambiguous labels + non_ambig_mask = (y_array != -1) + tc_chr['y'] = y_array + tc_chr = tc_chr[non_ambig_mask] + tc_chr.insert(len(tc_chr.columns), 'celltype', ct) pd_list.append(tc_chr) - metadata_df = pd.concat(pd_list) - - # Get the y values, and remove ambiguous labels by default. - y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values - non_ambig_mask = (y_array != -1) - metadata_df['y'] = y_array - self._metadata_df = metadata_df[non_ambig_mask] + print(time.time() - itime) + self._metadata_df = pd.concat(pd_list) + # Downsample negatives to balance each celltype samp_ndces = [] itime = time.time() - for ct in self._all_celltypes: - neg_msk = np.logical_and((self._metadata_df['celltype'] == ct), (self._metadata_df['y'] == 0)) - pos_msk = np.logical_and((self._metadata_df['celltype'] == ct), (self._metadata_df['y'] == 1)) - neg_ndces = np.where(neg_msk)[0] - pos_ndces = np.where(pos_msk)[0] + neg_msk = (self._metadata_df['y'] == 0) + pos_msk = (self._metadata_df['y'] == 1) + for ct in _all_celltypes: + celltype_msk = (self._metadata_df['celltype'] == ct) + print(ct, time.time() - itime) + neg_ct_msk = np.logical_and(celltype_msk, neg_msk) + pos_ct_msk = np.logical_and(celltype_msk, pos_msk) + print(ct, time.time() - itime) + neg_ndces = np.where(neg_ct_msk)[0] + pos_ndces = np.where(pos_ct_msk)[0] np.random.seed(42) samp_neg_ndces = np.random.choice(neg_ndces, size=len(pos_ndces), replace=False) samp_ndces.extend(samp_neg_ndces) @@ -169,21 +175,46 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): super().__init__(root_dir, download, split_scheme) + def get_random_label_vec(metadata_df, output_size=128): + # Sample a positively labeled region at random + pos_mdf = metadata_df[metadata_df['y'] == 1] #.iloc[ metadata_df['chr'] == s['chr'], : ] + pos_seed_region = pos_mdf.iloc[np.random.randint(pos_mdf.shape[0])] + + # Extract regions from this chromosome in this celltype, to get a window of labels from + chr_msk = np.array(metadata_df['chr']) == pos_seed_region['chr'] + ct_msk = np.array(metadata_df['celltype']) == pos_seed_region['celltype'] + mdf = metadata_df[chr_msk & ct_msk] + + # Get labels + start_ndx = np.where(mdf['start'] == pos_seed_region['start'])[0][0] + y_label_vec = mdf.iloc[start_ndx:start_ndx+output_size, :]['y'] + def get_input(self, idx): """ - Returns x for a given idx. + Returns x for a given idx in metadata_array, which has been filtered to only take windows with the desired stride. Computes this from: (1) sequence features in self._seq_bp - (2) DNase features in self._dnase_allcelltypes + (2) DNase bigwig file handles in self._dnase_allcelltypes (3) Metadata for the index (location along the genome with 200bp window width) """ + this_metadata = self._metadata_df.iloc[idx, :] + """ flank_size = 400 interval_start = this_metadata['start'] - flank_size interval_end = this_metadata['stop'] + flank_size dnase_this = self._dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end] seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end] return torch.tensor(np.column_stack([seq_this, dnase_this])) + """ + window_size = 12800 + interval_start = this_metadata['start'] + interval_end = window_size + interval_start #this_metadata['stop'] + seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end] + dnase_bw = self._dnase_allcelltypes[this_metadata['celltype']] + dnase_this = dnase_bw.values(chrom, interval_start, interval_end, numpy=True) + dnase_avg = self._dnase_allcelltypes['avg'].values(chrom, interval_start, interval_end, numpy=True) + return torch.tensor(np.column_stack([seq_this, dnase_this, dnase_avg])) def eval(self, y_pred, y_true, metadata): return self.standard_group_eval( From 4ace19b6e159494b10411328922dc4e5ea65a83e Mon Sep 17 00:00:00 2001 From: aikanor Date: Thu, 4 Mar 2021 20:12:23 -0800 Subject: [PATCH 027/244] final code (1/3) except eval, model fixes --- dataset_preprocessing/encode-tfbs/README.md | 8 +- .../encode-tfbs/prep_metadata_labels.ipynb | 382 ++++++ .../encode-tfbs/write_label_bigwig.py | 93 ++ examples/configs/datasets.py | 2 +- examples/models/initializer.py | 5 +- examples/sbox_run_expt.ipynb | 1142 +++++------------ wilds/datasets/encodetfbs_dataset.py | 103 +- 7 files changed, 839 insertions(+), 896 deletions(-) create mode 100644 dataset_preprocessing/encode-tfbs/prep_metadata_labels.ipynb create mode 100644 dataset_preprocessing/encode-tfbs/write_label_bigwig.py diff --git a/dataset_preprocessing/encode-tfbs/README.md b/dataset_preprocessing/encode-tfbs/README.md index bf3f92c6..7ecf1135 100644 --- a/dataset_preprocessing/encode-tfbs/README.md +++ b/dataset_preprocessing/encode-tfbs/README.md @@ -9,11 +9,11 @@ 2. Run `python prep_sequence.py --seq_path SEQUENCE_PATH --output_dir OUTPUT_DIR` to write the fasta file found in `SEQUENCE_PATH` to a numpy array archive in `OUTPUT_DIR`. -3. Download the accessibility data from the challenge. This consists of whole-genome DNase files in bigwig format (*.bw) from https://www.synapse.org/#!Synapse:syn6176233. +3. Download the DNase accessibility data. This consists of whole-genome DNase files in bigwig format from https://guanfiles.dcmb.med.umich.edu/Leopard/dnase_bigwig/. -4. Run `python prep_accessibility.py --input_dir INPUT_DIR --output_dir OUTPUT_DIR` to extract the bigwigs into numpy array archives, one per celltype. - -5. Download the labels from the challenge into a label directory created for this purpose: +4. Download the labels from the challenge into a label directory created for this purpose: - The training labels from https://www.synapse.org/#!Synapse:syn7413983 for the relevant transcription factor (e.g. https://www.synapse.org/#!Synapse:syn7415202 for the TF MAX). - The validation labels from https://www.synapse.org/#!Synapse:syn8441154 for the relevant transcription factor (e.g. https://www.synapse.org/#!Synapse:syn8442103 for the TF MAX). - (Optional) The validation labels for the challenge's evaluation cell type from https://www.synapse.org/#!Synapse:syn8442975 for the relevant transcription factor (generally primary liver cells, e.g. https://www.synapse.org/#!Synapse:syn8443021 for the TF MAX). + +5. Run `write_label_bigwig.py` diff --git a/dataset_preprocessing/encode-tfbs/prep_metadata_labels.ipynb b/dataset_preprocessing/encode-tfbs/prep_metadata_labels.ipynb new file mode 100644 index 00000000..9748bd25 --- /dev/null +++ b/dataset_preprocessing/encode-tfbs/prep_metadata_labels.ipynb @@ -0,0 +1,382 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import os, csv\n", + "import scipy, numpy as np, pandas as pd, time\n", + "from scipy import sparse\n", + "import pyBigWig\n", + "\n", + "# Human chromosome names\n", + "chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prep metadata df and metadata array surrounding the labels\n", + "- Metadata df contains 6400bp (window_size/2) prediction windows across the genome. Each gets a 128-bit prediction from the model.\n", + "- We store the ones that aren't fully unbound. All the rest are fully unbound." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "83.30138063430786\n", + "H1-hESC 100.73247504234314\n", + "HCT116 106.4023334980011\n", + "HeLa-S3 111.88021206855774\n", + "HepG2 117.56940197944641\n", + "K562 126.93423342704773\n", + "A549 138.21517205238342\n", + "GM12878 148.77391648292542\n", + "150.62964010238647\n", + "213.72714066505432\n" + ] + } + ], + "source": [ + "itime = time.time()\n", + "\n", + "_data_dir = '../../examples/data/encode-tfbs_v1.0/'\n", + "_transcription_factor = 'MAX'\n", + "_train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX']\n", + "_val_chroms = ['chr2', 'chr9', 'chr11']\n", + "_test_chroms = ['chr1', 'chr8', 'chr21']\n", + "_all_chroms = _train_chroms + _val_chroms + _test_chroms\n", + "_train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", + "_val_celltype = ['A549']\n", + "_test_celltype = ['GM12878']\n", + "_all_celltypes = _train_celltypes + _val_celltype + _test_celltype\n", + "\n", + "# Read in metadata dataframe from training+validation data\n", + "train_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.train.labels.tsv.gz'.format(_transcription_factor)), sep='\\t')\n", + "val_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.val.labels.tsv.gz'.format(_transcription_factor)), sep='\\t')\n", + "training_df = train_regions_labeled# [np.isin(train_regions_labeled['chr'], _train_chroms)]\n", + "val_df = val_regions_labeled# [np.isin(val_regions_labeled['chr'], _test_chroms)]\n", + "all_df = pd.concat([training_df, val_df])\n", + "\n", + "print(time.time() - itime)\n", + "\n", + "# Get the y values, and remove labels by default.\n", + "pd_list = []\n", + "for ct in _all_celltypes:\n", + " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", + " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", + " tc_chr = tc_chr[tc_chr['y'] != 'U']\n", + " tc_chr['y'] = tc_chr['y'].replace({'U': 0, 'B': 1, 'A': 0.5}).values\n", + " \n", + " tc_chr.insert(len(tc_chr.columns), 'celltype', ct)\n", + " pd_list.append(tc_chr)\n", + " print(ct, time.time() - itime)\n", + "_metadata_df = pd.concat(pd_list)\n", + "\n", + "print(time.time() - itime)\n", + "_unsorted_dir = _data_dir + 'labels/MAX/MAX_posamb.bed'\n", + "_sorted_dir = _unsorted_dir.replace('MAX_posamb', 'MAX_posamb.sorted')\n", + "_metadata_df.to_csv(\n", + " _unsorted_dir, sep='\\t', header=False, index=False\n", + ")\n", + "print(time.time() - itime)\n", + "\n", + "os.system('sort -k1,1 -k2,2n {} > {}'.format(_unsorted_dir, _sorted_dir))\n", + "\n", + "mdf_posamb = pd.read_csv(\n", + " _sorted_dir, \n", + " sep='\\t', header=None, index_col=None, names=['chr', 'start', 'stop', 'y', 'celltype']\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "H1-hESC 350.84476041793823\n", + "HCT116 358.2693498134613\n", + "HeLa-S3 364.6210968494415\n", + "HepG2 372.65956830978394\n", + "K562 380.6701240539551\n", + "A549 388.50364875793457\n", + "GM12878 394.2549338340759\n" + ] + } + ], + "source": [ + "chrom_sizes = {'chr1': 249250621, 'chr10': 135534747, 'chr11': 135006516, 'chr12': 133851895, 'chr13': 115169878, 'chr14': 107349540, 'chr15': 102531392, 'chr16': 90354753, 'chr17': 81195210, 'chr18': 78077248, 'chr19': 59128983, 'chr2': 243199373, 'chr20': 63025520, 'chr21': 48129895, 'chr22': 51304566, 'chr3': 198022430, 'chr4': 191154276, 'chr5': 180915260, 'chr6': 171115067, 'chr7': 159138663, 'chr8': 146364022, 'chr9': 141213431, 'chrX': 155270560}\n", + "chromsizes_list = [(k, v) for k, v in chrom_sizes.items()]\n", + "for ct in _all_celltypes:\n", + " ct_labels_bw_path = _data_dir + \"labels/MAX/MAX_{}.bigwig\".format(ct)\n", + " df = mdf_posamb[mdf_posamb['celltype'] == ct]\n", + " bw = pyBigWig.open(ct_labels_bw_path, \"w\")\n", + " bw.addHeader(chromsizes_list)\n", + " bw.addEntries(list(df['chr']), list(df['start']), ends=list(df['start']+50), values=list(df['y']))\n", + " print(ct, time.time() - itime)\n", + " bw.close()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + ":9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " df['window_start'] = stride*(df['start'] // stride)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "A549 63.97912120819092\n", + "GM12878 103.89278292655945\n", + "H1-hESC 182.84059262275696\n", + "HCT116 243.95744681358337\n", + "HeLa-S3 303.7187397480011\n", + "HepG2 375.8099205493927\n", + "K562 456.08897161483765\n", + "456.0923991203308\n", + "462.8749210834503\n" + ] + } + ], + "source": [ + "stride = 6400\n", + "itime = time.time()\n", + "celltype_mdta = []\n", + "celltype_labels = []\n", + "\n", + "for ct in _all_celltypes:\n", + " ct_labels_bw_path = _data_dir + \"labels/MAX/MAX_{}.bigwig\".format(ct)\n", + " df = mdf_posamb[mdf_posamb['celltype'] == ct]\n", + " df['window_start'] = stride*(df['start'] // stride)\n", + " uniq_windows = np.unique([\"{}:{}\".format(x[0], x[1]) for x in zip(df['chr'], df['window_start'])])\n", + " df_construction = []\n", + " mdta_labels = []\n", + " \n", + " bw = pyBigWig.open(ct_labels_bw_path)\n", + " num_reps = 0\n", + " for u in uniq_windows:\n", + " u_chr = u.split(':')[0]\n", + " u_start = int(u.split(':')[1])\n", + " u_end = u_start + stride\n", + " x = np.nan_to_num(bw.values(u_chr, u_start, u_end, numpy=True))\n", + " df_construction.append((u_chr, u_start, u_end))\n", + " mdta_labels.append(x[np.arange(0, len(x), 50)])\n", + " num_reps = num_reps + 1\n", + " celltype_mdta_df = pd.DataFrame(df_construction, columns=['chr', 'start', 'stop'])\n", + " celltype_mdta_df.insert(len(celltype_mdta_df.columns), 'celltype', ct)\n", + " celltype_mdta.append(celltype_mdta_df)\n", + " celltype_labels.append(np.stack(mdta_labels))\n", + " print(ct, time.time() - itime)\n", + " bw.close()\n", + " # break\n", + "print(time.time() - itime)\n", + "# _metadata_df\n", + "\n", + "pd.concat(celltype_mdta).to_csv(\n", + " _data_dir + 'labels/MAX/metadata_df.bed', \n", + " sep='\\t', header=False, index=False\n", + ")\n", + "np.save(_data_dir + 'labels/MAX/metadata_y.npy', np.vstack(celltype_labels))\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
chrstartstopcelltype
0chr10100025600100032000A549
1chr10100032000100038400A549
2chr10100064000100070400A549
3chr10100076800100083200A549
4chr10100083200100089600A549
...............
523753chrX9969920099705600K562
523754chrX99776009984000K562
523755chrX9990400099910400K562
523756chrX9992320099929600K562
523757chrX99993600100000000K562
\n", + "

523758 rows × 4 columns

\n", + "
" + ], + "text/plain": [ + " chr start stop celltype\n", + "0 chr10 100025600 100032000 A549\n", + "1 chr10 100032000 100038400 A549\n", + "2 chr10 100064000 100070400 A549\n", + "3 chr10 100076800 100083200 A549\n", + "4 chr10 100083200 100089600 A549\n", + "... ... ... ... ...\n", + "523753 chrX 99699200 99705600 K562\n", + "523754 chrX 9977600 9984000 K562\n", + "523755 chrX 99904000 99910400 K562\n", + "523756 chrX 99923200 99929600 K562\n", + "523757 chrX 99993600 100000000 K562\n", + "\n", + "[523758 rows x 4 columns]" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.read_csv(\n", + " _data_dir + 'labels/MAX/metadata_df.bed', sep='\\t', header=None, \n", + " index_col=None, names=['chr', 'start', 'stop', 'celltype']\n", + ")\n", + "# np.load(_data_dir + 'labels/MAX/metadata_y.npy')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/dataset_preprocessing/encode-tfbs/write_label_bigwig.py b/dataset_preprocessing/encode-tfbs/write_label_bigwig.py new file mode 100644 index 00000000..8dcf4f3e --- /dev/null +++ b/dataset_preprocessing/encode-tfbs/write_label_bigwig.py @@ -0,0 +1,93 @@ +import argparse, time +import numpy as np +import pyBigWig + +# Human hg19 chromosome names/lengths +chrom_sizes = {'chr1': 249250621, 'chr10': 135534747, 'chr11': 135006516, 'chr12': 133851895, 'chr13': 115169878, 'chr14': 107349540, 'chr15': 102531392, 'chr16': 90354753, 'chr17': 81195210, 'chr18': 78077248, 'chr19': 59128983, 'chr2': 243199373, 'chr20': 63025520, 'chr21': 48129895, 'chr22': 51304566, 'chr3': 198022430, 'chr4': 191154276, 'chr5': 180915260, 'chr6': 171115067, 'chr7': 159138663, 'chr8': 146364022, 'chr9': 141213431, 'chrX': 155270560} + +celltypes = ['A549', 'GM12878', 'H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] + +chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] + +def write_label_bigwig( + metadata_df, output_dir='codalab_archive' +): + dnases = {} + for ctype in celltypes: + itime = time.time() + bw = pyBigWig.open("{}/DNASE.{}.fc.signal.bigwig".format(input_dir, ctype)) + chromsizes = bw.chroms() + dn_dict = {} + for chrom in chromsizes: #chr_IDs: + x = bw.values(chrom, 0, chromsizes[chrom], numpy=True) + # half-precision makes things significantly smaller (less time to load) + dn_dict[chrom] = np.nan_to_num(x).astype(np.float16) + print("{}, {}. Time: {}".format(ctype, chrom, time.time() - itime)) + dnases[ctype] = dn_dict + + for ctype in dnases: + itime = time.time() + dn_dict = dnases[ctype] + + # Save as npz archive + np.savez_compressed('{}/{}_dnase'.format(output_dir, ctype), **dn_dict) + print("Saving npz archive for celltype {}. Time: {}".format(ctype, time.time() - itime)) + + +if __name__ == '__main__': + itime = time.time() + _data_dir = '../../examples/data/encode-tfbs_v1.0/' + _transcription_factor = 'MAX' + _train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] + _val_chroms = ['chr2', 'chr9', 'chr11'] + _test_chroms = ['chr1', 'chr8', 'chr21'] + _all_chroms = _train_chroms + _val_chroms + _test_chroms + _train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] + _val_celltype = ['A549'] + _test_celltype = ['GM12878'] + _all_celltypes = _train_celltypes + _val_celltype + _test_celltype + + # Read in metadata dataframe from training+validation data + train_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.train.labels.tsv.gz'.format(_transcription_factor)), sep='\t') + val_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.val.labels.tsv.gz'.format(_transcription_factor)), sep='\t') + training_df = train_regions_labeled# [np.isin(train_regions_labeled['chr'], _train_chroms)] + val_df = val_regions_labeled# [np.isin(val_regions_labeled['chr'], _test_chroms)] + all_df = pd.concat([training_df, val_df]) + + print(time.time() - itime) + + # Get the y values, and remove labels by default. + pd_list = [] + for ct in _all_celltypes: + tc_chr = all_df[['chr', 'start', 'stop', ct]] + tc_chr.columns = ['chr', 'start', 'stop', 'y'] + tc_chr = tc_chr[tc_chr['y'] != 'U'] + tc_chr['y'] = tc_chr['y'].replace({'U': 0, 'B': 1, 'A': 0.5}).values + tc_chr.insert(len(tc_chr.columns), 'celltype', ct) + pd_list.append(tc_chr) + print(ct, time.time() - itime) + _metadata_df = pd.concat(pd_list) + + print(time.time() - itime) + _unsorted_dir = _data_dir + 'labels/MAX/MAX_posamb.bed' + _sorted_dir = _unsorted_dir.replace('MAX_posamb', 'MAX_posamb.sorted') + _metadata_df.to_csv( + _unsorted_dir, sep='\t', header=False, index=False + ) + print(time.time() - itime) + + os.system('sort -k1,1 -k2,2n {} > {}'.format(_unsorted_dir, _sorted_dir)) + + mdf_posamb = pd.read_csv( + _sorted_dir, + sep='\t', header=None, index_col=None, names=['chr', 'start', 'stop', 'y', 'celltype'] + ) + chromsizes_list = [(k, v) for k, v in chrom_sizes.items()] + for ct in _all_celltypes: + ct_labels_bw_path = _data_dir + "labels/MAX/MAX_{}.bigwig".format(ct) + df = mdf_posamb[mdf_posamb['celltype'] == ct] + bw = pyBigWig.open(ct_labels_bw_path, "w") + bw.addHeader(chromsizes_list) + bw.addEntries(list(df['chr']), list(df['start']), ends=list(df['start']+50), values=list(df['y'])) + print(ct, time.time() - itime) + bw.close() diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index c4b900fd..07219823 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -96,7 +96,7 @@ 'train_transform': None, 'eval_transform': None, 'loss_function': 'cross_entropy', - 'groupby_fields': ['celltype', 'y'], + 'groupby_fields': ['celltype'], 'val_metric': 'acc_avg', 'val_metric_decreasing': False, 'optimizer': 'Adam', diff --git a/examples/models/initializer.py b/examples/models/initializer.py index 5de63c54..e37a4250 100644 --- a/examples/models/initializer.py +++ b/examples/models/initializer.py @@ -4,6 +4,7 @@ from models.resnet_multispectral import ResNet18 from models.layers import Identity from models.gnn import GINVirtual +from models.CNN_genome import UNet def initialize_model(config, d_out): print('Dout: {}'.format(d_out)) @@ -23,12 +24,12 @@ def initialize_model(config, d_out): config.model, num_labels=d_out, **config.model_kwargs) + elif config.model == 'leopard': + model = UNet(d_out) elif config.model == 'logistic_regression': model = nn.Linear(out_features=d_out, **config.model_kwargs) elif config.model == 'gin-virtual': model = GINVirtual(num_tasks=d_out, **config.model_kwargs) - # elif config.model == 'leopard': - # model = GINVirtual(num_tasks=d_out, **config.model_kwargs) else: raise ValueError('Model not recognized.') return model diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb index 06440dc6..071a68d7 100644 --- a/examples/sbox_run_expt.ipynb +++ b/examples/sbox_run_expt.ipynb @@ -33,43 +33,61 @@ { "cell_type": "code", "execution_count": 1, - "metadata": {}, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "163 µs ± 343 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" + "ename": "NameError", + "evalue": "name 'bw' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# import pyBigWig\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;31m# %timeit bw = pyBigWig.open(\"/users/abalsubr/wilds/examples/data/encode-tfbs_v1.0/DNASE.K562.fc.signal.bigwig\")\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mget_ipython\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_line_magic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'timeit'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"bw.values('chr1', 10000, 22800, numpy=True)\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mrun_line_magic\u001b[0;34m(self, magic_name, line, _stack_depth)\u001b[0m\n\u001b[1;32m 2334\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'local_ns'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_local_scope\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstack_depth\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2335\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuiltin_trap\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2336\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2337\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2338\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mtimeit\u001b[0;34m(self, line, cell, local_ns)\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/IPython/core/magic.py\u001b[0m in \u001b[0;36m\u001b[0;34m(f, *a, **k)\u001b[0m\n\u001b[1;32m 185\u001b[0m \u001b[0;31m# but it's overkill for just that one bit of state.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 186\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mmagic_deco\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 187\u001b[0;31m \u001b[0mcall\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 188\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcallable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/IPython/core/magics/execution.py\u001b[0m in \u001b[0;36mtimeit\u001b[0;34m(self, line, cell, local_ns)\u001b[0m\n\u001b[1;32m 1167\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mindex\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1168\u001b[0m \u001b[0mnumber\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m10\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0mindex\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1169\u001b[0;31m \u001b[0mtime_number\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtimer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtimeit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnumber\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1170\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtime_number\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0;36m0.2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1171\u001b[0m \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/IPython/core/magics/execution.py\u001b[0m in \u001b[0;36mtimeit\u001b[0;34m(self, number)\u001b[0m\n\u001b[1;32m 167\u001b[0m \u001b[0mgc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdisable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 168\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 169\u001b[0;31m \u001b[0mtiming\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minner\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mit\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtimer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 170\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 171\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mgcold\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36minner\u001b[0;34m(_it, _timer)\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name 'bw' is not defined" ] } ], "source": [ - "import pyBigWig\n", - "# %timeit bw = pyBigWig.open(\"/users/abalsubr/wilds/examples/data/encode-tfbs_v1.0/DNASE.K562.fc.signal.bigwig\")" + "# import pyBigWig\n", + "# %timeit bw = pyBigWig.open(\"/users/abalsubr/wilds/examples/data/encode-tfbs_v1.0/DNASE.K562.fc.signal.bigwig\")\n", + "%timeit bw.values('chr1', 10000, 22800, numpy=True)" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "WARNING:root:The OGB package is out of date. Your version is 1.2.4, while the latest version is 1.2.5.\n" + "WARNING:root:The OGB package is out of date. Your version is 1.2.4, while the latest version is 1.2.6.\n" ] } ], "source": [ - "import os, csv\n", + "import os, csv, sys\n", + "os.environ['CUDA_VISIBLE_DEVICES'] = '4'\n", + "\n", "import time\n", "import argparse\n", "import numpy as np, pandas as pd\n", "import torch\n", "import torch.nn as nn\n", "import torchvision\n", - "import sys\n", + "import pyBigWig\n", "from collections import defaultdict\n", "\n", "from wilds.common.data_loaders import get_train_loader, get_eval_loader\n", @@ -85,16 +103,16 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" + "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" ] }, - "execution_count": 3, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -189,7 +207,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -207,7 +225,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -236,7 +254,7 @@ "Resize scale: None\n", "Max token length: None\n", "Loss function: cross_entropy\n", - "Groupby fields: ['celltype', 'y']\n", + "Groupby fields: ['celltype']\n", "Group dro step size: None\n", "Coral penalty weight: None\n", "Irm lambda: None\n", @@ -270,41 +288,19 @@ "Progress bar: False\n", "Resume: False\n", "\n", - "chr2 3.6633927822113037\n", - "chr3 6.7115819454193115\n", - "chr4 9.648637771606445\n", - "chr5 12.439441919326782\n", - "chr6 15.091757774353027\n", - "chr7 17.542895555496216\n", - "chr9 19.707583904266357\n", - "chr10 21.79905652999878\n", - "chr11 23.86957049369812\n", - "chr12 25.918642044067383\n", - "chr13 27.675577402114868\n", - "chr14 29.3148353099823\n", - "chr15 30.881144046783447\n", - "chr16 32.271193504333496\n", - "chr17 33.51785063743591\n", - "chr18 34.72123050689697\n", - "chr19 35.627156257629395\n", - "chr20 36.59872794151306\n", - "chr22 37.37847852706909\n", - "chrX 39.77280807495117\n", - "chr1 43.60475468635559\n", - "chr8 45.86070203781128\n", - "chr21 46.59553360939026\n" - ] - }, - { - "ename": "NameError", - "evalue": "name '_all_celltypes' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;31m# Data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m full_dataset = supported.datasets[config.dataset](\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0mroot_dir\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mroot_dir\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mdownload\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdownload\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/wilds/datasets/encodetfbs_dataset.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, root_dir, download, split_scheme)\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0;31m# Get the y values, and remove ambiguous labels by default.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 102\u001b[0m \u001b[0mpd_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 103\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mct\u001b[0m \u001b[0;32min\u001b[0m \u001b[0m_all_celltypes\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 104\u001b[0m \u001b[0mtc_chr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mall_df\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'start'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'stop'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mct\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[0mtc_chr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcolumns\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'start'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'stop'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'y'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mNameError\u001b[0m: name '_all_celltypes' is not defined" + "chr3 5.088634967803955\n", + "chr4 9.974164009094238\n", + "chr5 15.149149894714355\n", + "chr6 19.728455066680908\n", + "chr7 23.769655466079712\n", + "chr10 29.31521511077881\n", + "chr12 32.78225326538086\n", + "chr13 35.67028570175171\n", + "chr14 46.721638441085815\n", + "chr15 92.16564106941223\n", + "chr16 96.26218318939209\n", + "chr17 114.85105729103088\n", + "chr18 116.09504199028015\n" ] } ], @@ -354,20 +350,13 @@ }, { "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" + "execution_count": 5, + "metadata": { + "jupyter": { + "source_hidden": true } - ], + }, + "outputs": [], "source": [ "import copy\n", "full_dataset_camelyon17 = copy.deepcopy(full_dataset)\n", @@ -376,56 +365,58 @@ "# print(config_camelyon.train_transform, config_encode.train_transform)\n" ] }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'full_dataset' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mfull_dataset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mNameError\u001b[0m: name 'full_dataset' is not defined" - ] - } - ], - "source": [ - "full_dataset" - ] - }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## 2) Initialize dataset object" + "## 2) Initialize dataset object (trial version)" ] }, { "cell_type": "code", "execution_count": 8, - "metadata": {}, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true, + "source_hidden": true + } + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "chr2 3.90254282951355\n", - "chr9 6.149690628051758\n", - "chr11 8.327073097229004\n", - "chr1 12.291624546051025\n", - "chr8 14.624409675598145\n", - "chr21 15.413429021835327\n", - "H1-hESC 15.415196895599365\n", - "HCT116 15.415941953659058\n", - "HeLa-S3 15.416455030441284\n", - "HepG2 15.417592763900757\n", - "K562 15.418397426605225\n", - "A549 15.41891360282898\n", - "GM12878 15.419732332229614\n" + "chr3 3.0872416496276855\n", + "chr4 6.014077425003052\n", + "chr5 8.789116859436035\n", + "chr6 11.409600496292114\n", + "chr7 13.844907283782959\n", + "chr10 15.919893264770508\n", + "chr12 17.969276189804077\n", + "chr13 19.71941637992859\n", + "chr14 21.34366464614868\n", + "chr15 22.900768995285034\n", + "chr16 24.27766728401184\n", + "chr17 25.519333600997925\n", + "chr18 26.714667797088623\n", + "chr19 27.614336490631104\n", + "chr20 28.57899522781372\n", + "chr22 29.353068113327026\n", + "chrX 31.731130599975586\n", + "chr2 35.449124813079834\n", + "chr9 37.5920934677124\n", + "chr11 39.65406608581543\n", + "chr1 43.44736051559448\n", + "chr8 45.68234419822693\n", + "chr21 46.41120982170105\n", + "H1-hESC 46.41424226760864\n", + "HCT116 46.41492676734924\n", + "HeLa-S3 46.41563010215759\n", + "HepG2 46.41687893867493\n", + "K562 46.41777992248535\n", + "A549 46.41860294342041\n", + "GM12878 46.41955780982971\n" ] } ], @@ -446,18 +437,18 @@ "_dataset_name = 'encode-tfbs'\n", "_version = '1.0'\n", "_download_url = 'https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/'\n", - "_data_dir = 'data/encode-tfbs_v1.0'\n", + "_data_dir = 'data/encode-tfbs_v1.0/'\n", "_y_size = 1\n", "_n_classes = 2\n", "\n", - "# _train_chroms = ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX']\n", - "_train_chroms = ['chr2', 'chr9', 'chr11']\n", + "_train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX']\n", + "_val_chroms = ['chr2', 'chr9', 'chr11']\n", "_test_chroms = ['chr1', 'chr8', 'chr21']\n", "_transcription_factor = 'MAX'\n", "_train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", "_val_celltype = ['A549']\n", "_test_celltype = ['GM12878']\n", - "_all_chroms = _train_chroms + _test_chroms\n", + "_all_chroms = _train_chroms + _val_chroms + _test_chroms\n", "_all_celltypes = _train_celltypes + _val_celltype + _test_celltype\n", "\n", "_metadata_map = {}\n", @@ -486,7 +477,7 @@ "sequence_filename = os.path.join(_data_dir, 'sequence.npz')\n", "seq_arr = np.load(sequence_filename)\n", "_seq_bp = {}\n", - "for chrom in _all_chroms: #seq_arr:\n", + "for chrom in _all_chroms:\n", " _seq_bp[chrom] = seq_arr[chrom]\n", " print(chrom, time.time() - itime)\n", "\n", @@ -504,533 +495,75 @@ " \"\"\"\n", " dnase_bw_path = os.path.join(_data_dir, 'Leopard_dnase/{}.bigwig'.format(ct))\n", " _dnase_allcelltypes[ct] = pyBigWig.open(dnase_bw_path)\n", - " print(ct, time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "74.06488299369812\n" - ] - } - ], - "source": [ - "itime = time.time()\n", - "\n", - "# Read in metadata dataframe from training+validation data\n", - "train_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.train.labels.tsv.gz'.format(_transcription_factor)), sep='\\t')\n", - "val_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.val.labels.tsv.gz'.format(_transcription_factor)), sep='\\t')\n", - "training_df = train_regions_labeled[np.isin(train_regions_labeled['chr'], _train_chroms)]\n", - "val_df = val_regions_labeled[np.isin(val_regions_labeled['chr'], _test_chroms)]\n", - "all_df = pd.concat([training_df, val_df])\n", - "\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - ":8: SettingWithCopyWarning: \n", - "A value is trying to be set on a copy of a slice from a DataFrame.\n", - "Try using .loc[row_indexer,col_indexer] = value instead\n", - "\n", - "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", - " tc_chr['y'] = tc_chr['y'].replace({'U': 0, 'B': 1, 'A': 0.5}).values\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "13.203907012939453\n", - "22.746795415878296\n", - "32.076903104782104\n", - "41.43525505065918\n", - "51.37267017364502\n", - "61.07364773750305\n", - "71.10506510734558\n", - "100.72125267982483\n" - ] - } - ], - "source": [ - "itime = time.time()\n", - "\n", - "# Get the y values, and remove ambiguous labels by default.\n", - "pd_list = []\n", - "for ct in _all_celltypes:\n", - " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", - " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", - " tc_chr['y'] = tc_chr['y'].replace({'U': 0, 'B': 1, 'A': 0.5}).values\n", - " \n", - " # # Now filter out ambiguous labels\n", - " # non_ambig_mask = (y_array != -1)\n", - " # tc_chr['y'] = y_array\n", - " # tc_chr = tc_chr[non_ambig_mask]\n", - " \n", - " tc_chr.insert(len(tc_chr.columns), 'celltype', ct)\n", - " pd_list.append(tc_chr)\n", - " print(time.time() - itime)\n", - "metadata_df = pd.concat(pd_list)\n", - "\n", - "print(time.time() - itime)\n", + " print(ct, time.time() - itime)\n", "\n", - "_metadata_df = metadata_df\n" + "_metadata_df = pd.read_csv(\n", + " _data_dir + 'labels/MAX/metadata_df.bed', sep='\\t', header=None, \n", + " index_col=None, names=['chr', 'start', 'stop', 'celltype']\n", + ")" ] }, { "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "H1-hESC 8.80908489227295\n", - "H1-hESC 12.474784135818481\n", - "H1-hESC 15.258162498474121\n", - "HCT116 23.023517370224\n", - "HCT116 25.439095735549927\n", - "HCT116 27.21790099143982\n", - "HeLa-S3 34.81065845489502\n", - "HeLa-S3 35.2776780128479\n", - "HeLa-S3 36.60090255737305\n", - "HepG2 44.36072087287903\n", - "HepG2 44.74501991271973\n", - "HepG2 46.02569603919983\n", - "K562 53.825233697891235\n", - "K562 54.182188749313354\n", - "K562 55.44522547721863\n", - "A549 62.980581283569336\n", - "A549 63.34522008895874\n", - "A549 64.59721446037292\n", - "GM12878 72.41460752487183\n", - "GM12878 72.7955391407013\n", - "GM12878 74.05369997024536\n" - ] + "execution_count": 325, + "metadata": { + "jupyter": { + "source_hidden": true } - ], - "source": [ - "# np.unique(_metadata_df['y'])\n", - "\n", - "# Downsample negatives to balance each celltype\n", - "samp_ndces = []\n", - "itime = time.time()\n", - "neg_msk = (_metadata_df['y'] == 0)\n", - "pos_msk = (_metadata_df['y'] != 0)\n", - "for ct in _all_celltypes:\n", - " celltype_msk = (_metadata_df['celltype'] == ct)\n", - " print(ct, time.time() - itime)\n", - " neg_ct_msk = np.logical_and(celltype_msk, neg_msk)\n", - " pos_ct_msk = np.logical_and(celltype_msk, pos_msk)\n", - " print(ct, time.time() - itime)\n", - " neg_ndces = np.where(neg_ct_msk)[0]\n", - " pos_ndces = np.where(pos_ct_msk)[0]\n", - " np.random.seed(42)\n", - " samp_neg_ndces = np.random.choice(neg_ndces, size=len(pos_ndces), replace=False)\n", - " samp_ndces.extend(samp_neg_ndces)\n", - " samp_ndces.extend(pos_ndces)\n", - " print(ct, time.time() - itime)\n", - "_metadata_df = _metadata_df.iloc[samp_ndces, :]" - ] - }, - { - "cell_type": "code", - "execution_count": 145, - "metadata": {}, + }, "outputs": [], "source": [ - "def get_random_label_vec(metadata_df, output_size=128):\n", - " # Sample a positively labeled region at random\n", - " pos_mdf = metadata_df[metadata_df['y'] == 1] #.iloc[ metadata_df['chr'] == s['chr'], : ]\n", - " pos_seed_region = pos_mdf.iloc[np.random.randint(pos_mdf.shape[0])]\n", - "\n", + "def get_random_label_vec(\n", + " metadata_df, seed_chr, seed_celltype, seed_start, output_size=128\n", + "):\n", + " \"\"\"\n", + " Given a coordinate in a celltype, gets the labels of \n", + " the `output_size` 200bp bins from that coordinate onward. \n", + " \"\"\"\n", + " itime = time.time()\n", + " \n", " # Extract regions from this chromosome in this celltype, to get a window of labels from\n", - " chr_msk = np.array(metadata_df['chr']) == pos_seed_region['chr']\n", - " ct_msk = np.array(metadata_df['celltype']) == pos_seed_region['celltype']\n", - " mdf = metadata_df[chr_msk & ct_msk]\n", + " # print(time.time() - itime)\n", + " # chr_msk = np.array(metadata_df['chr']) == seed_region['chr']\n", + " # print(time.time() - itime)\n", + " # ct_msk = np.array(metadata_df['celltype']) == seed_region['celltype']\n", + " # mdf = metadata_df[chr_msk & ct_msk]\n", + " seq_size = output_size*50\n", + " mdf = metadata_df.loc[\n", + " (metadata_df['chr'] == seed_chr) & \n", + " (metadata_df['celltype'] == seed_celltype) & \n", + " (metadata_df['start'] >= seed_start) & \n", + " (metadata_df['stop'] < seed_start+seq_size)\n", + " ]\n", + " print(time.time() - itime)\n", "\n", " # Get labels\n", - " start_ndx = np.where(mdf['start'] == pos_seed_region['start'])[0][0]\n", - " y_label_vec = mdf.iloc[start_ndx:start_ndx+output_size, :]['y']" + " y_label_vec = np.zeros(output_size)\n", + " y_label_vec[(mdf['start'] - seed_start) // 50] = mdf['y']\n", + " return mdf, y_label_vec" ] }, { "cell_type": "code", - "execution_count": 146, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
chrstartstopycelltype
5937924chr86008000.0HepG2
5937925chr86508500.0HepG2
5937926chr87009000.0HepG2
5937927chr87509500.0HepG2
5937928chr880010000.0HepG2
..................
8843006chr81463632001463634000.0HepG2
8843007chr81463632501463634500.0HepG2
8843008chr81463633001463635000.0HepG2
8843009chr81463633501463635500.0HepG2
8843010chr81463634001463636000.0HepG2
\n", - "

2905087 rows × 5 columns

\n", - "
" - ], - "text/plain": [ - " chr start stop y celltype\n", - "5937924 chr8 600 800 0.0 HepG2\n", - "5937925 chr8 650 850 0.0 HepG2\n", - "5937926 chr8 700 900 0.0 HepG2\n", - "5937927 chr8 750 950 0.0 HepG2\n", - "5937928 chr8 800 1000 0.0 HepG2\n", - "... ... ... ... ... ...\n", - "8843006 chr8 146363200 146363400 0.0 HepG2\n", - "8843007 chr8 146363250 146363450 0.0 HepG2\n", - "8843008 chr8 146363300 146363500 0.0 HepG2\n", - "8843009 chr8 146363350 146363550 0.0 HepG2\n", - "8843010 chr8 146363400 146363600 0.0 HepG2\n", - "\n", - "[2905087 rows x 5 columns]" - ] - }, - "execution_count": 146, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 154, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "8809571 1.0\n", - "8809572 1.0\n", - "8809573 1.0\n", - "8809574 1.0\n", - "8809575 1.0\n", - "8809576 0.5\n", - "8809577 0.5\n", - "8809578 0.5\n", - "8809579 1.0\n", - "8809580 1.0\n", - "8809581 1.0\n", - "8809582 1.0\n", - "8809583 1.0\n", - "8809584 1.0\n", - "8809585 0.5\n", - "8809586 0.5\n", - "8809587 0.0\n", - "8809588 0.0\n", - "8809589 0.0\n", - "8809590 0.0\n", - "8809591 0.0\n", - "8809592 0.0\n", - "8809593 0.0\n", - "8809594 0.0\n", - "8809595 0.0\n", - "8809596 0.0\n", - "8809597 0.0\n", - "8809598 0.0\n", - "8809599 0.0\n", - "8809600 0.0\n", - "8809601 0.0\n", - "8809602 0.0\n", - "Name: y, dtype: float64" - ] - }, - "execution_count": 154, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 107, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([0.])" - ] - }, - "execution_count": 107, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.unique(mdf[:256+start_bin]['y'])" - ] - }, - { - "cell_type": "code", - "execution_count": 150, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "chr chr8\n", - "start 144691450\n", - "stop 144691650\n", - "y 1.0\n", - "celltype HepG2\n", - "Name: 8809571, dtype: object" - ] - }, - "execution_count": 150, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "pos_seed_region" - ] - }, - { - "cell_type": "code", - "execution_count": 98, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([ 600, 600, 600, ..., 135005900, 135005900,\n", - " 135005900])" - ] - }, - "execution_count": 98, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# arr = metadata_df[mdf_msk]['start']\n", - "#arr == \n", - "np.sort(arr)" - ] - }, - { - "cell_type": "code", - "execution_count": 69, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "116.39193439483643\n" - ] - } - ], - "source": [ - "itime = time.time()\n", - "lts = ['{}:{}-{}'.format(x[0], x[1], x[2]) for x in zip(metadata_df['chr'], metadata_df['start'], metadata_df['stop'])]\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 55, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "202800" - ] - }, - "execution_count": 55, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#_metadata_df['y']\n", - "# s = metadata_df.iloc[np.array(pos_msk), :]\n", - "ntry = s.iloc[5]\n", - "ntry['start'] + 12800\n", - "# s['chr'], s['start'], s['stop'] # np.unique(s['chr'], return_counts=True)\n", - "# all_df\n", - "# metadata_df" - ] - }, - { - "cell_type": "code", - "execution_count": 16, + "execution_count": 13, "metadata": { - "collapsed": true, "jupyter": { - "outputs_hidden": true + "source_hidden": true } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "H1-hESC 7.871182918548584\n", - "H1-hESC 8.298529148101807\n", - "H1-hESC 9.57175898551941\n", - "HCT116 17.01794719696045\n", - "HCT116 17.36267113685608\n", - "HCT116 18.669682025909424\n", - "HeLa-S3 26.405478954315186\n", - "HeLa-S3 26.759119272232056\n", - "HeLa-S3 28.043395042419434\n", - "HepG2 35.623862981796265\n", - "HepG2 35.98245143890381\n", - "HepG2 37.29869079589844\n", - "K562 44.92080807685852\n", - "K562 45.256179332733154\n", - "K562 46.7364935874939\n", - "A549 54.39264512062073\n", - "A549 54.74424934387207\n", - "A549 56.03351712226868\n", - "GM12878 63.745240211486816\n", - "GM12878 64.1029920578003\n", - "GM12878 65.43286633491516\n" - ] - } - ], + "outputs": [], "source": [ "train_regions_mask = np.isin(_metadata_df['chr'], _train_chroms)\n", - "val_regions_mask = np.isin(_metadata_df['chr'], _test_chroms)\n", + "val_regions_mask = np.isin(_metadata_df['chr'], _val_chroms)\n", + "test_regions_mask = np.isin(_metadata_df['chr'], _test_chroms)\n", "train_celltype_mask = np.isin(_metadata_df['celltype'], _train_celltypes)\n", "val_celltype_mask = np.isin(_metadata_df['celltype'], _val_celltype)\n", "test_celltype_mask = np.isin(_metadata_df['celltype'], _test_celltype)\n", "\n", "split_array = -1*np.ones(_metadata_df.shape[0]).astype(int)\n", "split_array[np.logical_and(train_regions_mask, train_celltype_mask)] = _split_dict['train']\n", - "split_array[np.logical_and(val_regions_mask, test_celltype_mask)] = _split_dict['test']\n", - "# Validate using test chr, either using a designated validation cell line ('val') or a training cell line ('id_val')\n", + "split_array[np.logical_and(test_regions_mask, test_celltype_mask)] = _split_dict['test']\n", + "# Validate using validation chr, either using a designated validation cell line ('val') or a training cell line ('id_val')\n", "split_array[np.logical_and(val_regions_mask, val_celltype_mask)] = _split_dict['val']\n", "split_array[np.logical_and(val_regions_mask, train_celltype_mask)] = _split_dict['id_val']\n", "\n", @@ -1039,19 +572,22 @@ "else:\n", " raise ValueError(f'Split scheme {_split_scheme} not recognized')\n", "\n", + "metadata_mask = (_metadata_df['split'] != -1)\n", "_metadata_df = _metadata_df[_metadata_df['split'] != -1]\n", - "_split_array = _metadata_df['split'].values\n", "\n", "chr_ints = _metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(_metadata_map['chr'])] )).values\n", "celltype_ints = _metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(_metadata_map['celltype'])] )).values\n", - "_y_array = torch.LongTensor(np.array(_metadata_df['y']))\n", + "_split_array = _metadata_df['split'].values\n", + "\n", + "_y_array = torch.Tensor(np.load(_data_dir + 'labels/MAX/metadata_y.npy'))\n", + "_y_array = _y_array[metadata_mask]\n", "\n", "_metadata_array = torch.stack(\n", " (torch.LongTensor(chr_ints), \n", - " torch.LongTensor(celltype_ints), \n", - " _y_array),\n", + " torch.LongTensor(celltype_ints)\n", + " ),\n", " dim=1)\n", - "_metadata_fields = ['chr', 'celltype', 'y']" + "_metadata_fields = ['chr', 'celltype']" ] }, { @@ -1063,8 +599,12 @@ }, { "cell_type": "code", - "execution_count": 3, - "metadata": {}, + "execution_count": 24, + "metadata": { + "jupyter": { + "source_hidden": true + } + }, "outputs": [], "source": [ "import os, time\n", @@ -1099,17 +639,17 @@ " self._version = '1.0'\n", " self._download_url = 'https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/'\n", " self._data_dir = self.initialize_data_dir(root_dir, download)\n", - " self._y_size = 1\n", - " self._n_classes = 2\n", + " self._y_size = 128\n", + " # self._n_classes = 2\n", " \n", - " # self._train_chroms = ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX']\n", - " self._train_chroms = ['chr2', 'chr9', 'chr11']\n", + " self._train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX']\n", + " self._val_chroms = ['chr2', 'chr9', 'chr11']\n", " self._test_chroms = ['chr1', 'chr8', 'chr21']\n", " self._transcription_factor = 'MAX'\n", " self._train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", " self._val_celltype = ['A549']\n", " self._test_celltype = ['GM12878']\n", - " self._all_chroms = self._train_chroms + self._test_chroms\n", + " self._all_chroms = self._train_chroms + self._val_chroms + self._test_chroms\n", " self._all_celltypes = self._train_celltypes + self._val_celltype + self._test_celltype\n", " \n", " self._metadata_map = {}\n", @@ -1143,6 +683,9 @@ " print(chrom, time.time() - itime)\n", " \n", " self._dnase_allcelltypes = {}\n", + " ct = 'avg'\n", + " dnase_avg_bw_path = os.path.join(self._data_dir, 'Leopard_dnase/{}.bigwig'.format(ct))\n", + " self._dnase_allcelltypes[ct] = pyBigWig.open(dnase_avg_bw_path)\n", " for ct in self._all_celltypes:\n", " \"\"\"\n", " dnase_filename = os.path.join(self._data_dir, '{}_dnase.npz'.format(ct))\n", @@ -1151,63 +694,25 @@ " for chrom in self._all_chroms: #self._seq_bp:\n", " self._dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom]\n", " \"\"\"\n", - " self._dnase_allcelltypes[ct] = os.path.join(self._data_dir, 'DNASE.{}.fc.signal.bigwig'.format(ct))\n", - " print(ct, time.time() - itime)\n", - " \n", - " # Read in metadata dataframe from training+validation data\n", - " train_regions_labeled = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.train.labels.tsv.gz'.format(self._transcription_factor)), sep='\\t')\n", - " val_regions_labeled = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.val.labels.tsv.gz'.format(self._transcription_factor)), sep='\\t')\n", - " training_df = train_regions_labeled[np.isin(train_regions_labeled['chr'], self._train_chroms)]\n", - " val_df = val_regions_labeled[np.isin(val_regions_labeled['chr'], self._test_chroms)]\n", - " all_df = pd.concat([training_df, val_df])\n", + " dnase_bw_path = os.path.join(self._data_dir, 'Leopard_dnase/{}.bigwig'.format(ct))\n", + " self._dnase_allcelltypes[ct] = pyBigWig.open(dnase_bw_path)\n", " \n", - " # Get the y values, and remove ambiguous labels by default.\n", - " pd_list = []\n", - " for ct in _all_celltypes:\n", - " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", - " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", - " y_array = tc_chr['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", - "\n", - " # Now filter out ambiguous labels\n", - " non_ambig_mask = (y_array != -1)\n", - " tc_chr['y'] = y_array\n", - " tc_chr = tc_chr[non_ambig_mask]\n", - "\n", - " tc_chr.insert(len(tc_chr.columns), 'celltype', ct)\n", - " pd_list.append(tc_chr)\n", - " print(time.time() - itime)\n", - " self._metadata_df = pd.concat(pd_list)\n", - " \n", - " # Downsample negatives to balance each celltype\n", - " samp_ndces = []\n", - " itime = time.time()\n", - " neg_msk = (self._metadata_df['y'] == 0)\n", - " pos_msk = (self._metadata_df['y'] == 1)\n", - " for ct in _all_celltypes:\n", - " celltype_msk = (self._metadata_df['celltype'] == ct)\n", - " print(ct, time.time() - itime)\n", - " neg_ct_msk = np.logical_and(celltype_msk, neg_msk)\n", - " pos_ct_msk = np.logical_and(celltype_msk, pos_msk)\n", - " print(ct, time.time() - itime)\n", - " neg_ndces = np.where(neg_ct_msk)[0]\n", - " pos_ndces = np.where(pos_ct_msk)[0]\n", - " np.random.seed(42)\n", - " samp_neg_ndces = np.random.choice(neg_ndces, size=len(pos_ndces), replace=False)\n", - " samp_ndces.extend(samp_neg_ndces)\n", - " samp_ndces.extend(pos_ndces)\n", - " print(ct, time.time() - itime)\n", - " self._metadata_df = self._metadata_df.iloc[samp_ndces, :]\n", + " self._metadata_df = pd.read_csv(\n", + " self._data_dir + '/labels/MAX/metadata_df.bed', sep='\\t', header=None, \n", + " index_col=None, names=['chr', 'start', 'stop', 'celltype']\n", + " )\n", " \n", " train_regions_mask = np.isin(self._metadata_df['chr'], self._train_chroms)\n", - " val_regions_mask = np.isin(self._metadata_df['chr'], self._test_chroms)\n", + " val_regions_mask = np.isin(self._metadata_df['chr'], self._val_chroms)\n", + " test_regions_mask = np.isin(self._metadata_df['chr'], self._test_chroms)\n", " train_celltype_mask = np.isin(self._metadata_df['celltype'], self._train_celltypes)\n", " val_celltype_mask = np.isin(self._metadata_df['celltype'], self._val_celltype)\n", " test_celltype_mask = np.isin(self._metadata_df['celltype'], self._test_celltype)\n", " \n", " split_array = -1*np.ones(self._metadata_df.shape[0]).astype(int)\n", " split_array[np.logical_and(train_regions_mask, train_celltype_mask)] = self._split_dict['train']\n", - " split_array[np.logical_and(val_regions_mask, test_celltype_mask)] = self._split_dict['test']\n", - " # Validate using test chr, either using a designated validation cell line ('val') or a training cell line ('id_val')\n", + " split_array[np.logical_and(test_regions_mask, test_celltype_mask)] = self._split_dict['test']\n", + " # Validate using validation chr, either using a designated validation cell line ('val') or a training cell line ('id_val')\n", " split_array[np.logical_and(val_regions_mask, val_celltype_mask)] = self._split_dict['val']\n", " split_array[np.logical_and(val_regions_mask, train_celltype_mask)] = self._split_dict['id_val']\n", " \n", @@ -1216,19 +721,21 @@ " else:\n", " raise ValueError(f'Split scheme {self._split_scheme} not recognized')\n", " \n", + " metadata_mask = (self._metadata_df['split'] != -1)\n", " self._metadata_df = self._metadata_df[self._metadata_df['split'] != -1]\n", - " self._split_array = self._metadata_df['split'].values\n", " \n", " chr_ints = self._metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['chr'])] )).values\n", " celltype_ints = self._metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['celltype'])] )).values\n", - " self._y_array = torch.LongTensor(np.array(self._metadata_df['y']))\n", + " self._split_array = self._metadata_df['split'].values\n", + " self._y_array = torch.Tensor(np.load(self._data_dir + '/labels/MAX/metadata_y.npy'))\n", + " self._y_array = self._y_array[metadata_mask]\n", " \n", " self._metadata_array = torch.stack(\n", " (torch.LongTensor(chr_ints), \n", - " torch.LongTensor(celltype_ints), \n", - " self._y_array),\n", + " torch.LongTensor(celltype_ints)\n", + " ),\n", " dim=1)\n", - " self._metadata_fields = ['chr', 'celltype', 'y']\n", + " self._metadata_fields = ['chr', 'celltype']\n", " \n", " self._eval_grouper = CombinatorialGrouper(\n", " dataset=self,\n", @@ -1237,33 +744,43 @@ " self._metric = Accuracy()\n", " \n", " super().__init__(root_dir, download, split_scheme)\n", - "\n", - " def get_input(self, idx):\n", + " \n", + " \"\"\"\n", + " def get_random_label_vec(metadata_df, output_size=128):\n", + " # Sample a positively labeled region at random\n", + " pos_mdf = metadata_df[metadata_df['y'] == 1] #.iloc[ metadata_df['chr'] == s['chr'], : ]\n", + " pos_seed_region = pos_mdf.iloc[np.random.randint(pos_mdf.shape[0])]\n", + "\n", + " # Extract regions from this chromosome in this celltype, to get a window of labels from\n", + " chr_msk = np.array(metadata_df['chr']) == pos_seed_region['chr']\n", + " ct_msk = np.array(metadata_df['celltype']) == pos_seed_region['celltype']\n", + " mdf = metadata_df[chr_msk & ct_msk]\n", + "\n", + " # Get labels\n", + " start_ndx = np.where(mdf['start'] == pos_seed_region['start'])[0][0]\n", + " y_label_vec = mdf.iloc[start_ndx:start_ndx+output_size, :]['y']\n", + " \"\"\"\n", + " \n", + " def get_input(self, idx, window_size=12800):\n", " \"\"\"\n", " Returns x for a given idx in metadata_array, which has been filtered to only take windows with the desired stride.\n", " Computes this from: \n", " (1) sequence features in self._seq_bp\n", - " (2) DNase bigwig file paths in self._dnase_allcelltypes\n", - " (3) Metadata for the index (location along the genome with 200bp window width)\n", + " (2) DNase bigwig file handles in self._dnase_allcelltypes\n", + " (3) Metadata for the index (location along the genome with 6400bp window width)\n", + " (4) Window_size, the length of sequence returned (centered on the 6400bp region in (3))\n", " \"\"\"\n", - " \n", " this_metadata = self._metadata_df.iloc[idx, :]\n", - " \"\"\"\n", - " flank_size = 400\n", - " interval_start = this_metadata['start'] - flank_size\n", - " interval_end = this_metadata['stop'] + flank_size\n", - " dnase_this = self._dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end]\n", - " seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end]\n", - " return torch.tensor(np.column_stack([seq_this, dnase_this]))\n", - " \"\"\"\n", - " window_size = 12800\n", - " interval_start = this_metadata['start']\n", - " interval_end = this_metadata['stop'] + window_size\n", + " interval_start = this_metadata['start'] - int(window_size/4)\n", + " interval_end = interval_start + window_size #this_metadata['stop']\n", " seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end]\n", " dnase_bw = self._dnase_allcelltypes[this_metadata['celltype']]\n", " dnase_this = dnase_bw.values(chrom, interval_start, interval_end, numpy=True)\n", - " return torch.tensor(np.column_stack([seq_this, dnase_this]))\n", - " \n", + " dnase_avg = self._dnase_allcelltypes['avg'].values(chrom, interval_start, interval_end, numpy=True)\n", + " return torch.tensor(np.column_stack(\n", + " [np.nan_to_num(seq_this), np.nan_to_num(dnase_this), np.nan_to_num(dnase_avg)]\n", + " ))\n", + "\n", " def eval(self, y_pred, y_true, metadata):\n", " return self.standard_group_eval(\n", " self._metric,\n", @@ -1273,7 +790,7 @@ }, { "cell_type": "code", - "execution_count": 107, + "execution_count": 26, "metadata": { "collapsed": true, "jupyter": { @@ -1285,40 +802,29 @@ "name": "stdout", "output_type": "stream", "text": [ - "chr2 3.962329387664795\n", - "chr9 6.259538888931274\n", - "chr11 8.446826934814453\n", - "chr1 12.49940538406372\n", - "chr8 14.91869592666626\n", - "chr21 15.700694799423218\n", - "H1-hESC 23.95099449157715\n", - "HCT116 31.26502823829651\n", - "HeLa-S3 39.382277488708496\n", - "HepG2 47.24500226974487\n", - "K562 55.079211711883545\n", - "A549 62.405343532562256\n", - "GM12878 70.00356984138489\n", - "H1-hESC 8.160386562347412\n", - "H1-hESC 8.546203374862671\n", - "H1-hESC 9.868412971496582\n", - "HCT116 17.121587991714478\n", - "HCT116 17.524660110473633\n", - "HCT116 18.90956425666809\n", - "HeLa-S3 26.98938488960266\n", - "HeLa-S3 27.376858234405518\n", - "HeLa-S3 28.7989022731781\n", - "HepG2 36.29348182678223\n", - "HepG2 36.668752908706665\n", - "HepG2 38.151512145996094\n", - "K562 45.96789216995239\n", - "K562 46.33995985984802\n", - "K562 47.87280249595642\n", - "A549 55.380892276763916\n", - "A549 55.75924301147461\n", - "A549 57.22686314582825\n", - "GM12878 65.09361720085144\n", - "GM12878 65.50619888305664\n", - "GM12878 66.9196424484253\n" + "chr3 3.0425407886505127\n", + "chr4 5.967821359634399\n", + "chr5 8.747126340866089\n", + "chr6 11.370141744613647\n", + "chr7 13.802208423614502\n", + "chr10 15.875979900360107\n", + "chr12 17.929850339889526\n", + "chr13 19.67976665496826\n", + "chr14 21.306750059127808\n", + "chr15 22.866544723510742\n", + "chr16 24.241100788116455\n", + "chr17 25.480982303619385\n", + "chr18 26.677065134048462\n", + "chr19 27.579110622406006\n", + "chr20 28.545915603637695\n", + "chr22 29.323810577392578\n", + "chrX 31.698036670684814\n", + "chr2 35.40705943107605\n", + "chr9 37.5518524646759\n", + "chr11 39.61783218383789\n", + "chr1 43.411964893341064\n", + "chr8 45.64823389053345\n", + "chr21 46.377281188964844\n" ] } ], @@ -1331,91 +837,28 @@ ] }, { - "cell_type": "code", - "execution_count": 2, + "cell_type": "markdown", "metadata": {}, - "outputs": [ - { - "ename": "ModuleNotFoundError", - "evalue": "No module named 'pyBigWig'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# full_dataset_encode\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mpyBigWig\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'pyBigWig'" - ] - } - ], "source": [ - "# full_dataset_encode\n", - "import pyBigWig" + "# Initialize algorithm" ] }, { "cell_type": "code", - "execution_count": 39, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "a = np.random.choice(1210796, size=128)\n", - "seta = [full_dataset_encode.get_input(x) for x in a]\n", - "seta[0].shape\n", - "\n", - "# full_dataset = copy.deepcopy(full_dataset_encode)\n", - "# full_dataset = copy.deepcopy(full_dataset_camelyon17)\n", - "# full_dataset_camelyon17.split_dict\n", - "\n", - "# full_dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Initialize algorithm" + "config" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Train data...\n", - " hospital = 0: n = 53425\n", - " hospital = 1: n = 0\n", - " hospital = 2: n = 0\n", - " hospital = 3: n = 116959\n", - " hospital = 4: n = 132052\n", - "Validation (ID) data...\n", - " hospital = 0: n = 6011\n", - " hospital = 1: n = 0\n", - " hospital = 2: n = 0\n", - " hospital = 3: n = 12879\n", - " hospital = 4: n = 14670\n", - "Test data...\n", - " hospital = 0: n = 0\n", - " hospital = 1: n = 0\n", - " hospital = 2: n = 85054\n", - " hospital = 3: n = 0\n", - " hospital = 4: n = 0\n", - "Validation (OOD) data...\n", - " hospital = 0: n = 0\n", - " hospital = 1: n = 34904\n", - " hospital = 2: n = 0\n", - " hospital = 3: n = 0\n", - " hospital = 4: n = 0\n", - "Dout: 2\n" - ] - } - ], + "outputs": [], "source": [ - "config = config_camelyon\n", - "\n", + "# config = config_encode\n", "\n", "train_grouper = CombinatorialGrouper(\n", " dataset=full_dataset,\n", @@ -1488,6 +931,77 @@ " train_grouper=train_grouper)" ] }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Available objects for config:\n", + " AliasManager\n", + " DisplayFormatter\n", + " HistoryManager\n", + " IPCompleter\n", + " IPKernelApp\n", + " LoggingMagics\n", + " MagicsManager\n", + " OSMagics\n", + " PrefilterManager\n", + " ScriptMagics\n", + " StoreMagics\n", + " ZMQInteractiveShell\n" + ] + } + ], + "source": [ + "config" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "device(type='cuda', index=0)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "config.device" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'np' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfull_dataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_input\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;31m#\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name 'np' is not defined" + ] + } + ], + "source": [ + "np.array(full_dataset.get_input(0)).shape\n", + "#" + ] + }, { "cell_type": "code", "execution_count": 29, @@ -1545,7 +1059,12 @@ { "cell_type": "code", "execution_count": 30, - "metadata": {}, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, "outputs": [ { "data": { @@ -1678,33 +1197,22 @@ }, { "cell_type": "code", - "execution_count": 126, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 126, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "for b in full_dataset:\n", - " break" - ] + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] }, { "cell_type": "code", "execution_count": 33, - "metadata": { - "jupyter": { - "source_hidden": true - } - }, + "metadata": {}, "outputs": [], "source": [ "\n", diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 04f5d08d..b5657597 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -33,17 +33,17 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): self._version = '1.0' self._download_url = 'https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/' self._data_dir = self.initialize_data_dir(root_dir, download) - self._y_size = 1 - self._n_classes = 2 + self._y_size = 128 + # self._n_classes = 2 - self._train_chroms = ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] - # self._train_chroms = ['chr2', 'chr9', 'chr11'] + self._train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] + self._val_chroms = ['chr2', 'chr9', 'chr11'] self._test_chroms = ['chr1', 'chr8', 'chr21'] self._transcription_factor = 'MAX' self._train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] self._val_celltype = ['A549'] self._test_celltype = ['GM12878'] - self._all_chroms = self._train_chroms + self._test_chroms + self._all_chroms = self._train_chroms + self._val_chroms + self._test_chroms self._all_celltypes = self._train_celltypes + self._val_celltype + self._test_celltype self._metadata_map = {} @@ -91,60 +91,22 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): dnase_bw_path = os.path.join(self._data_dir, 'Leopard_dnase/{}.bigwig'.format(ct)) self._dnase_allcelltypes[ct] = pyBigWig.open(dnase_bw_path) - # Read in metadata dataframe from training+validation data - train_regions_labeled = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.train.labels.tsv.gz'.format(self._transcription_factor)), sep='\t') - val_regions_labeled = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.val.labels.tsv.gz'.format(self._transcription_factor)), sep='\t') - training_df = train_regions_labeled[np.isin(train_regions_labeled['chr'], self._train_chroms)] - val_df = val_regions_labeled[np.isin(val_regions_labeled['chr'], self._test_chroms)] - all_df = pd.concat([training_df, val_df]) - - # Get the y values, and remove ambiguous labels by default. - pd_list = [] - for ct in self._all_celltypes: - tc_chr = all_df[['chr', 'start', 'stop', ct]] - tc_chr.columns = ['chr', 'start', 'stop', 'y'] - y_array = tc_chr['y'].replace({'U': 0, 'B': 1, 'A': -1}).values - - # Now filter out ambiguous labels - non_ambig_mask = (y_array != -1) - tc_chr['y'] = y_array - tc_chr = tc_chr[non_ambig_mask] - - tc_chr.insert(len(tc_chr.columns), 'celltype', ct) - pd_list.append(tc_chr) - print(time.time() - itime) - self._metadata_df = pd.concat(pd_list) - - # Downsample negatives to balance each celltype - samp_ndces = [] - itime = time.time() - neg_msk = (self._metadata_df['y'] == 0) - pos_msk = (self._metadata_df['y'] == 1) - for ct in _all_celltypes: - celltype_msk = (self._metadata_df['celltype'] == ct) - print(ct, time.time() - itime) - neg_ct_msk = np.logical_and(celltype_msk, neg_msk) - pos_ct_msk = np.logical_and(celltype_msk, pos_msk) - print(ct, time.time() - itime) - neg_ndces = np.where(neg_ct_msk)[0] - pos_ndces = np.where(pos_ct_msk)[0] - np.random.seed(42) - samp_neg_ndces = np.random.choice(neg_ndces, size=len(pos_ndces), replace=False) - samp_ndces.extend(samp_neg_ndces) - samp_ndces.extend(pos_ndces) - print(ct, time.time() - itime) - self._metadata_df = self._metadata_df.iloc[samp_ndces, :] + self._metadata_df = pd.read_csv( + self._data_dir + '/labels/MAX/metadata_df.bed', sep='\t', header=None, + index_col=None, names=['chr', 'start', 'stop', 'celltype'] + ) train_regions_mask = np.isin(self._metadata_df['chr'], self._train_chroms) - val_regions_mask = np.isin(self._metadata_df['chr'], self._test_chroms) + val_regions_mask = np.isin(self._metadata_df['chr'], self._val_chroms) + test_regions_mask = np.isin(self._metadata_df['chr'], self._test_chroms) train_celltype_mask = np.isin(self._metadata_df['celltype'], self._train_celltypes) val_celltype_mask = np.isin(self._metadata_df['celltype'], self._val_celltype) test_celltype_mask = np.isin(self._metadata_df['celltype'], self._test_celltype) split_array = -1*np.ones(self._metadata_df.shape[0]).astype(int) split_array[np.logical_and(train_regions_mask, train_celltype_mask)] = self._split_dict['train'] - split_array[np.logical_and(val_regions_mask, test_celltype_mask)] = self._split_dict['test'] - # Validate using test chr, either using a designated validation cell line ('val') or a training cell line ('id_val') + split_array[np.logical_and(test_regions_mask, test_celltype_mask)] = self._split_dict['test'] + # Validate using validation chr, either using a designated validation cell line ('val') or a training cell line ('id_val') split_array[np.logical_and(val_regions_mask, val_celltype_mask)] = self._split_dict['val'] split_array[np.logical_and(val_regions_mask, train_celltype_mask)] = self._split_dict['id_val'] @@ -153,19 +115,21 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): else: raise ValueError(f'Split scheme {self._split_scheme} not recognized') + metadata_mask = (self._metadata_df['split'] != -1) self._metadata_df = self._metadata_df[self._metadata_df['split'] != -1] - self._split_array = self._metadata_df['split'].values chr_ints = self._metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['chr'])] )).values celltype_ints = self._metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['celltype'])] )).values - self._y_array = torch.LongTensor(np.array(self._metadata_df['y'])) + self._split_array = self._metadata_df['split'].values + self._y_array = torch.Tensor(np.load(self._data_dir + '/labels/MAX/metadata_y.npy')) + self._y_array = self._y_array[metadata_mask] self._metadata_array = torch.stack( (torch.LongTensor(chr_ints), - torch.LongTensor(celltype_ints), - self._y_array), + torch.LongTensor(celltype_ints) + ), dim=1) - self._metadata_fields = ['chr', 'celltype', 'y'] + self._metadata_fields = ['chr', 'celltype'] self._eval_grouper = CombinatorialGrouper( dataset=self, @@ -174,7 +138,8 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): self._metric = Accuracy() super().__init__(root_dir, download, split_scheme) - + + """ def get_random_label_vec(metadata_df, output_size=128): # Sample a positively labeled region at random pos_mdf = metadata_df[metadata_df['y'] == 1] #.iloc[ metadata_df['chr'] == s['chr'], : ] @@ -188,33 +153,27 @@ def get_random_label_vec(metadata_df, output_size=128): # Get labels start_ndx = np.where(mdf['start'] == pos_seed_region['start'])[0][0] y_label_vec = mdf.iloc[start_ndx:start_ndx+output_size, :]['y'] + """ - def get_input(self, idx): + def get_input(self, idx, window_size=12800): """ Returns x for a given idx in metadata_array, which has been filtered to only take windows with the desired stride. Computes this from: (1) sequence features in self._seq_bp (2) DNase bigwig file handles in self._dnase_allcelltypes - (3) Metadata for the index (location along the genome with 200bp window width) + (3) Metadata for the index (location along the genome with 6400bp window width) + (4) Window_size, the length of sequence returned (centered on the 6400bp region in (3)) """ - this_metadata = self._metadata_df.iloc[idx, :] - """ - flank_size = 400 - interval_start = this_metadata['start'] - flank_size - interval_end = this_metadata['stop'] + flank_size - dnase_this = self._dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end] - seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end] - return torch.tensor(np.column_stack([seq_this, dnase_this])) - """ - window_size = 12800 - interval_start = this_metadata['start'] - interval_end = window_size + interval_start #this_metadata['stop'] + interval_start = this_metadata['start'] - int(window_size/4) + interval_end = interval_start + window_size #this_metadata['stop'] seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end] dnase_bw = self._dnase_allcelltypes[this_metadata['celltype']] dnase_this = dnase_bw.values(chrom, interval_start, interval_end, numpy=True) dnase_avg = self._dnase_allcelltypes['avg'].values(chrom, interval_start, interval_end, numpy=True) - return torch.tensor(np.column_stack([seq_this, dnase_this, dnase_avg])) + return torch.tensor(np.column_stack( + [np.nan_to_num(seq_this), np.nan_to_num(dnase_this), np.nan_to_num(dnase_avg)] + )) def eval(self, y_pred, y_true, metadata): return self.standard_group_eval( From 96aa06394e2ad839db2e4b8ee93b5145545258cb Mon Sep 17 00:00:00 2001 From: aikanor Date: Fri, 5 Mar 2021 10:26:05 -0800 Subject: [PATCH 028/244] final code (1/3) --- examples/models/CNN_genome.py | 4 +- examples/models/initializer.py | 2 +- examples/sbox_run_expt.ipynb | 671 +++++++++++++++++++++++---- wilds/datasets/encodetfbs_dataset.py | 4 +- 4 files changed, 580 insertions(+), 101 deletions(-) diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py index 147f8c9e..4f851706 100644 --- a/examples/models/CNN_genome.py +++ b/examples/models/CNN_genome.py @@ -26,7 +26,7 @@ def double_conv(in_channels, out_channels): class UNet(nn.Module): - def __init__(self, n_class, n_channels_in=6): + def __init__(self, out_features=16, n_channels_in=6): super().__init__() self.dconv_down1 = double_conv(n_channels_in, 15) @@ -46,7 +46,7 @@ def __init__(self, n_class, n_channels_in=6): self.dconv_up2 = double_conv(22 + 33, 22) self.dconv_up1 = double_conv(15 + 22, 15) - self.conv_last = nn.Conv2d(15, n_class, 1) + self.conv_last = nn.Conv1d(15, out_features, 1) def forward(self, x): diff --git a/examples/models/initializer.py b/examples/models/initializer.py index e37a4250..10dd5693 100644 --- a/examples/models/initializer.py +++ b/examples/models/initializer.py @@ -25,7 +25,7 @@ def initialize_model(config, d_out): num_labels=d_out, **config.model_kwargs) elif config.model == 'leopard': - model = UNet(d_out) + model = UNet(out_features=d_out) elif config.model == 'logistic_regression': model = nn.Linear(out_features=d_out, **config.model_kwargs) elif config.model == 'gin-virtual': diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb index 071a68d7..92cb1746 100644 --- a/examples/sbox_run_expt.ipynb +++ b/examples/sbox_run_expt.ipynb @@ -109,7 +109,7 @@ { "data": { "text/plain": [ - "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" + "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" ] }, "execution_count": 2, @@ -207,7 +207,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -219,20 +219,20 @@ "config_encode = parser.parse_args(argstr_encode.split())\n", "config_encode = populate_defaults(config_encode)\n", "\n", - "# config = config_camelyon\n", - "config = config_encode\n" + "config = config_camelyon\n", + "# config = config_encode\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Dataset: encode-tfbs\n", + "Dataset: camelyon17\n", "Algorithm: ERM\n", "Root dir: data\n", "Split scheme: official\n", @@ -244,26 +244,26 @@ "Uniform over groups: False\n", "Distinct groups: None\n", "N groups per batch: 2\n", - "Batch size: 64\n", + "Batch size: 32\n", "Eval loader: standard\n", - "Model: leopard\n", + "Model: densenet121\n", "Model kwargs: {'pretrained': False}\n", - "Train transform: None\n", - "Eval transform: None\n", - "Target resolution: None\n", + "Train transform: image_base\n", + "Eval transform: image_base\n", + "Target resolution: (224, 224)\n", "Resize scale: None\n", "Max token length: None\n", "Loss function: cross_entropy\n", - "Groupby fields: ['celltype']\n", + "Groupby fields: ['hospital']\n", "Group dro step size: None\n", - "Coral penalty weight: None\n", - "Irm lambda: None\n", + "Coral penalty weight: 0.1\n", + "Irm lambda: 1.0\n", "Irm penalty anneal iters: None\n", "Algo log metric: accuracy\n", "Val metric: acc_avg\n", "Val metric decreasing: False\n", "N epochs: 5\n", - "Optimizer: Adam\n", + "Optimizer: SGD\n", "Lr: 0.001\n", "Weight decay: 0.01\n", "Max grad norm: None\n", @@ -287,20 +287,7 @@ "Use wandb: False\n", "Progress bar: False\n", "Resume: False\n", - "\n", - "chr3 5.088634967803955\n", - "chr4 9.974164009094238\n", - "chr5 15.149149894714355\n", - "chr6 19.728455066680908\n", - "chr7 23.769655466079712\n", - "chr10 29.31521511077881\n", - "chr12 32.78225326538086\n", - "chr13 35.67028570175171\n", - "chr14 46.721638441085815\n", - "chr15 92.16564106941223\n", - "chr16 96.26218318939209\n", - "chr17 114.85105729103088\n", - "chr18 116.09504199028015\n" + "\n" ] } ], @@ -845,18 +832,62 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'algorithm' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mNameError\u001b[0m: name 'algorithm' is not defined" + ] + } + ], "source": [ - "config" + "algorithm.model" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train data...\n", + " hospital = 0: n = 53425\n", + " hospital = 1: n = 0\n", + " hospital = 2: n = 0\n", + " hospital = 3: n = 116959\n", + " hospital = 4: n = 132052\n", + "Validation (ID) data...\n", + " hospital = 0: n = 6011\n", + " hospital = 1: n = 0\n", + " hospital = 2: n = 0\n", + " hospital = 3: n = 12879\n", + " hospital = 4: n = 14670\n", + "Test data...\n", + " hospital = 0: n = 0\n", + " hospital = 1: n = 0\n", + " hospital = 2: n = 85054\n", + " hospital = 3: n = 0\n", + " hospital = 4: n = 0\n", + "Validation (OOD) data...\n", + " hospital = 0: n = 0\n", + " hospital = 1: n = 34904\n", + " hospital = 2: n = 0\n", + " hospital = 3: n = 0\n", + " hospital = 4: n = 0\n", + "Dout: 2\n" + ] + } + ], "source": [ "# config = config_encode\n", "\n", @@ -933,73 +964,521 @@ }, { "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Available objects for config:\n", - " AliasManager\n", - " DisplayFormatter\n", - " HistoryManager\n", - " IPCompleter\n", - " IPKernelApp\n", - " LoggingMagics\n", - " MagicsManager\n", - " OSMagics\n", - " PrefilterManager\n", - " ScriptMagics\n", - " StoreMagics\n", - " ZMQInteractiveShell\n" - ] - } - ], - "source": [ - "config" - ] - }, - { - "cell_type": "code", - "execution_count": 13, + "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "device(type='cuda', index=0)" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "config.device" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'np' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfull_dataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_input\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;31m#\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mNameError\u001b[0m: name 'np' is not defined" - ] - } - ], - "source": [ - "np.array(full_dataset.get_input(0)).shape\n", - "#" + "DenseNet(\n", + " (features): Sequential(\n", + " (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", + " (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu0): ReLU(inplace=True)\n", + " (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", + " (denseblock1): _DenseBlock(\n", + " (denselayer1): _DenseLayer(\n", + " (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer2): _DenseLayer(\n", + " (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer3): _DenseLayer(\n", + " (norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer4): _DenseLayer(\n", + " (norm1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer5): _DenseLayer(\n", + " (norm1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer6): _DenseLayer(\n", + " (norm1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " )\n", + " (transition1): _Transition(\n", + " (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", + " )\n", + " (denseblock2): _DenseBlock(\n", + " (denselayer1): _DenseLayer(\n", + " (norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer2): _DenseLayer(\n", + " (norm1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer3): _DenseLayer(\n", + " (norm1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer4): _DenseLayer(\n", + " (norm1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer5): _DenseLayer(\n", + " (norm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer6): _DenseLayer(\n", + " (norm1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer7): _DenseLayer(\n", + " (norm1): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer8): _DenseLayer(\n", + " (norm1): BatchNorm2d(352, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer9): _DenseLayer(\n", + " (norm1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer10): _DenseLayer(\n", + " (norm1): BatchNorm2d(416, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer11): _DenseLayer(\n", + " (norm1): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(448, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer12): _DenseLayer(\n", + " (norm1): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(480, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " )\n", + " (transition2): _Transition(\n", + " (norm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", + " )\n", + " (denseblock3): _DenseBlock(\n", + " (denselayer1): _DenseLayer(\n", + " (norm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer2): _DenseLayer(\n", + " (norm1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer3): _DenseLayer(\n", + " (norm1): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer4): _DenseLayer(\n", + " (norm1): BatchNorm2d(352, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer5): _DenseLayer(\n", + " (norm1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer6): _DenseLayer(\n", + " (norm1): BatchNorm2d(416, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer7): _DenseLayer(\n", + " (norm1): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(448, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer8): _DenseLayer(\n", + " (norm1): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(480, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer9): _DenseLayer(\n", + " (norm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer10): _DenseLayer(\n", + " (norm1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer11): _DenseLayer(\n", + " (norm1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer12): _DenseLayer(\n", + " (norm1): BatchNorm2d(608, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer13): _DenseLayer(\n", + " (norm1): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer14): _DenseLayer(\n", + " (norm1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer15): _DenseLayer(\n", + " (norm1): BatchNorm2d(704, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(704, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer16): _DenseLayer(\n", + " (norm1): BatchNorm2d(736, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(736, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer17): _DenseLayer(\n", + " (norm1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer18): _DenseLayer(\n", + " (norm1): BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(800, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer19): _DenseLayer(\n", + " (norm1): BatchNorm2d(832, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer20): _DenseLayer(\n", + " (norm1): BatchNorm2d(864, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(864, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer21): _DenseLayer(\n", + " (norm1): BatchNorm2d(896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer22): _DenseLayer(\n", + " (norm1): BatchNorm2d(928, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(928, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer23): _DenseLayer(\n", + " (norm1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(960, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer24): _DenseLayer(\n", + " (norm1): BatchNorm2d(992, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(992, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " )\n", + " (transition3): _Transition(\n", + " (norm): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", + " )\n", + " (denseblock4): _DenseBlock(\n", + " (denselayer1): _DenseLayer(\n", + " (norm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer2): _DenseLayer(\n", + " (norm1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer3): _DenseLayer(\n", + " (norm1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer4): _DenseLayer(\n", + " (norm1): BatchNorm2d(608, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer5): _DenseLayer(\n", + " (norm1): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer6): _DenseLayer(\n", + " (norm1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer7): _DenseLayer(\n", + " (norm1): BatchNorm2d(704, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(704, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer8): _DenseLayer(\n", + " (norm1): BatchNorm2d(736, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(736, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer9): _DenseLayer(\n", + " (norm1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer10): _DenseLayer(\n", + " (norm1): BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(800, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer11): _DenseLayer(\n", + " (norm1): BatchNorm2d(832, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer12): _DenseLayer(\n", + " (norm1): BatchNorm2d(864, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(864, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer13): _DenseLayer(\n", + " (norm1): BatchNorm2d(896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer14): _DenseLayer(\n", + " (norm1): BatchNorm2d(928, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(928, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer15): _DenseLayer(\n", + " (norm1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(960, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer16): _DenseLayer(\n", + " (norm1): BatchNorm2d(992, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(992, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " )\n", + " (norm5): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (classifier): Linear(in_features=1024, out_features=2, bias=True)\n", + ")" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "algorithm.model" ] }, { diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index b5657597..23b70014 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -5,7 +5,7 @@ import pyBigWig from wilds.datasets.wilds_dataset import WILDSDataset from wilds.common.grouper import CombinatorialGrouper -from wilds.common.metrics.all_metrics import Accuracy +from wilds.common.metrics.all_metrics import Accuracy, MultiTaskAccuracy all_chrom_names = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] @@ -135,7 +135,7 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): dataset=self, groupby_fields=['celltype']) - self._metric = Accuracy() + self._metric = MultiTaskAccuracy() super().__init__(root_dir, download, split_scheme) From 6631381d7c4b1fd79b0dcaba16cdb55765a39b17 Mon Sep 17 00:00:00 2001 From: aikanor Date: Thu, 11 Mar 2021 07:57:37 -0800 Subject: [PATCH 029/244] final integration w/eval code 2/3 --- examples/configs/data_loader.py | 2 +- examples/configs/datasets.py | 7 +- examples/configs/model.py | 4 +- examples/configs/supported.py | 3 +- examples/models/CNN_genome.py | 24 +- examples/sbox_run_expt.ipynb | 1960 +++++--------------------- wilds/common/metrics/all_metrics.py | 30 +- wilds/common/metrics/metric.py | 1 + wilds/datasets/encodetfbs_dataset.py | 15 +- 9 files changed, 435 insertions(+), 1611 deletions(-) diff --git a/examples/configs/data_loader.py b/examples/configs/data_loader.py index 38741464..c00b1b64 100644 --- a/examples/configs/data_loader.py +++ b/examples/configs/data_loader.py @@ -1,6 +1,6 @@ loader_defaults = { 'loader_kwargs':{ - 'num_workers': 4, + 'num_workers': 1, 'pin_memory': True, }, 'n_groups_per_batch': 4, diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 07219823..5072d5bc 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -95,19 +95,18 @@ 'model_kwargs': {'pretrained': False}, 'train_transform': None, 'eval_transform': None, - 'loss_function': 'cross_entropy', + 'loss_function': 'multitask_bce', 'groupby_fields': ['celltype'], 'val_metric': 'acc_avg', 'val_metric_decreasing': False, - 'optimizer': 'Adam', - # 'optimizer_kwargs': { }, + 'optimizer': 'Adam', 'scheduler': None, 'batch_size': 64, 'lr': 0.001, 'weight_decay': 0.01, 'n_epochs': 5, 'n_groups_per_batch': 2, - 'algo_log_metric': 'accuracy', + 'algo_log_metric': 'multitask_avgprec', # 'irm_lambda': 1.0, # 'coral_penalty_weight': 0.1, }, diff --git a/examples/configs/model.py b/examples/configs/model.py index 6d4c88c3..51618754 100644 --- a/examples/configs/model.py +++ b/examples/configs/model.py @@ -27,5 +27,7 @@ 'target_resolution': (224, 224), }, 'logistic_regression': {}, - 'leopard': {}, + 'leopard': { + 'optimizer': 'Adam' + }, } diff --git a/examples/configs/supported.py b/examples/configs/supported.py index bdf73267..bf7f73cc 100644 --- a/examples/configs/supported.py +++ b/examples/configs/supported.py @@ -16,7 +16,7 @@ from wilds.datasets.yelp_dataset import YelpDataset # metrics from wilds.common.metrics.loss import ElementwiseLoss, Loss, MultiTaskLoss -from wilds.common.metrics.all_metrics import Accuracy, MultiTaskAccuracy, MSE +from wilds.common.metrics.all_metrics import Accuracy, MultiTaskAccuracy, MSE, MultiTaskAveragePrecision datasets = { 'amazon': AmazonDataset, @@ -43,6 +43,7 @@ 'accuracy': Accuracy(), 'mse': MSE(), 'multitask_accuracy': MultiTaskAccuracy(), + 'multitask_avgprec': MultiTaskAveragePrecision(), None: None, } diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py index 4f851706..7397eeb2 100644 --- a/examples/models/CNN_genome.py +++ b/examples/models/CNN_genome.py @@ -37,16 +37,23 @@ def __init__(self, out_features=16, n_channels_in=6): self.dconv_down6 = double_conv(73, 109) self.maxpool = nn.MaxPool1d(2) + # self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) self.conv_middle = single_conv(109, 109) + self.upsamp_6 = nn.ConvTranspose1d(109, 109, 2, stride=2) self.dconv_up5 = double_conv(73 + 109, 73) + self.upsamp_5 = nn.ConvTranspose1d(73, 73, 2, stride=2) self.dconv_up4 = double_conv(49 + 73, 49) + self.upsamp_4 = nn.ConvTranspose1d(49, 49, 2, stride=2) self.dconv_up3 = double_conv(33 + 49, 33) + self.upsamp_3 = nn.ConvTranspose1d(33, 33, 2, stride=2) self.dconv_up2 = double_conv(22 + 33, 22) + self.upsamp_2 = nn.ConvTranspose1d(22, 22, 2, stride=2) self.dconv_up1 = double_conv(15 + 22, 15) + self.upsamp_1 = nn.ConvTranspose1d(15, 15, 2, stride=2) - self.conv_last = nn.Conv1d(15, out_features, 1) + self.conv_last = nn.Conv1d(15, 1, 200, stride=50, padding=0) def forward(self, x): @@ -72,27 +79,28 @@ def forward(self, x): # Encoder finished. - x = self.upsample(conv6) # (input_size / 16) x 109 + x = self.upsamp_6(conv6) # (input_size / 16) x 109 x = torch.cat([x, conv5], dim=1) # (input_size / 16) x (109 + 73) x = self.dconv_up5(x) # (input_size / 16) x 73 - x = self.upsample(x) # (input_size / 8) x 73 + x = self.upsamp_5(x) # (input_size / 8) x 73 x = torch.cat([x, conv4], dim=1) # (input_size / 8) x (73 + 49) x = self.dconv_up4(x) # (input_size / 8) x 49 - x = self.upsample(x) # (input_size / 4) x 49 + x = self.upsamp_4(x) # (input_size / 4) x 49 x = torch.cat([x, conv3], dim=1) # (input_size / 4) x (49 + 33) x = self.dconv_up3(x) # (input_size / 4) x 33 - x = self.upsample(x) # (input_size / 2) x 33 + x = self.upsamp_3(x) # (input_size / 2) x 33 x = torch.cat([x, conv2], dim=1) # (input_size / 2) x (33 + 22) x = self.dconv_up2(x) # (input_size / 2) x 22 - x = self.upsample(x) # (input_size) x 22 + x = self.upsamp_2(x) # (input_size) x 22 x = torch.cat([x, conv1], dim=1) # (input_size) x (22 + 15) x = self.dconv_up1(x) # (input_size) x 15 - out = self.conv_last(x) + # middle 128 bits + out = self.conv_last(x)[:, :, 64:192] - return out + return torch.squeeze(out) diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb index 92cb1746..5040aeb0 100644 --- a/examples/sbox_run_expt.ipynb +++ b/examples/sbox_run_expt.ipynb @@ -11,18 +11,14 @@ }, { "cell_type": "code", - "execution_count": 123, + "execution_count": 27, "metadata": {}, "outputs": [ { - "ename": "ModuleNotFoundError", - "evalue": "No module named 'psutil'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpsutil\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpsutil\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mProcess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgetpid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmemory_info\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrss\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0;36m1024\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'psutil'" + "name": "stdout", + "output_type": "stream", + "text": [ + "4860.4765625\n" ] } ], @@ -109,7 +105,7 @@ { "data": { "text/plain": [ - "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" + "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" ] }, "execution_count": 2, @@ -120,7 +116,7 @@ "source": [ "''' set default hyperparams in default_hyperparams.py '''\n", "parser = argparse.ArgumentParser()\n", - "CombinatorialGrouper\n", + "\n", "# Required arguments\n", "parser.add_argument('-d', '--dataset', choices=supported.datasets, required=True)\n", "parser.add_argument('--algorithm', required=True, choices=supported.algorithms)\n", @@ -207,11 +203,12 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "argstr_camelyon = \"--dataset camelyon17 --algorithm ERM --root_dir data\"\n", + "# argstr_camelyon = \"--dataset civilcomments --algorithm ERM --root_dir data\"\n", "config_camelyon = parser.parse_args(argstr_camelyon.split())\n", "config_camelyon = populate_defaults(config_camelyon)\n", "\n", @@ -220,7 +217,31 @@ "config_encode = populate_defaults(config_encode)\n", "\n", "config = config_camelyon\n", - "# config = config_encode\n" + "config = config_encode\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "argstr_camelyon = \"--dataset camelyon17 --algorithm ERM --root_dir data\"\n", + "# argstr_camelyon = \"--dataset civilcomments --algorithm ERM --root_dir data\"\n", + "config_camelyon = parser.parse_args(argstr_camelyon.split())\n", + "\n", + "argstr_encode = \"--dataset encode-tfbs --algorithm ERM --root_dir data\"\n", + "config_encode = parser.parse_args(argstr_encode.split())\n", + "config_encode" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "config.optimizer_kwargs = {}" ] }, { @@ -232,42 +253,42 @@ "name": "stdout", "output_type": "stream", "text": [ - "Dataset: camelyon17\n", + "Dataset: encode-tfbs\n", "Algorithm: ERM\n", "Root dir: data\n", "Split scheme: official\n", "Dataset kwargs: {}\n", "Download: False\n", "Frac: 1.0\n", - "Loader kwargs: {'num_workers': 4, 'pin_memory': True}\n", + "Loader kwargs: {'num_workers': 1, 'pin_memory': True}\n", "Train loader: standard\n", "Uniform over groups: False\n", "Distinct groups: None\n", "N groups per batch: 2\n", - "Batch size: 32\n", + "Batch size: 64\n", "Eval loader: standard\n", - "Model: densenet121\n", + "Model: leopard\n", "Model kwargs: {'pretrained': False}\n", - "Train transform: image_base\n", - "Eval transform: image_base\n", - "Target resolution: (224, 224)\n", + "Train transform: None\n", + "Eval transform: None\n", + "Target resolution: None\n", "Resize scale: None\n", "Max token length: None\n", - "Loss function: cross_entropy\n", - "Groupby fields: ['hospital']\n", + "Loss function: multitask_bce\n", + "Groupby fields: ['celltype']\n", "Group dro step size: None\n", - "Coral penalty weight: 0.1\n", - "Irm lambda: 1.0\n", + "Coral penalty weight: None\n", + "Irm lambda: None\n", "Irm penalty anneal iters: None\n", - "Algo log metric: accuracy\n", + "Algo log metric: multitask_avgprec\n", "Val metric: acc_avg\n", "Val metric decreasing: False\n", "N epochs: 5\n", - "Optimizer: SGD\n", + "Optimizer: Adam\n", "Lr: 0.001\n", "Weight decay: 0.01\n", "Max grad norm: None\n", - "Optimizer kwargs: {'momentum': 0.9}\n", + "Optimizer kwargs: {}\n", "Scheduler: None\n", "Scheduler kwargs: {}\n", "Scheduler metric split: val\n", @@ -287,7 +308,10 @@ "Use wandb: False\n", "Progress bar: False\n", "Resume: False\n", - "\n" + "\n", + "chr3 2.9614717960357666\n", + "chr2 6.587897777557373\n", + "chr1 10.29332971572876\n" ] } ], @@ -337,13 +361,20 @@ }, { "cell_type": "code", - "execution_count": 5, - "metadata": { - "jupyter": { - "source_hidden": true + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" } - }, - "outputs": [], + ], "source": [ "import copy\n", "full_dataset_camelyon17 = copy.deepcopy(full_dataset)\n", @@ -361,12 +392,11 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 6, "metadata": { "collapsed": true, "jupyter": { - "outputs_hidden": true, - "source_hidden": true + "outputs_hidden": true } }, "outputs": [ @@ -374,36 +404,36 @@ "name": "stdout", "output_type": "stream", "text": [ - "chr3 3.0872416496276855\n", - "chr4 6.014077425003052\n", - "chr5 8.789116859436035\n", - "chr6 11.409600496292114\n", - "chr7 13.844907283782959\n", - "chr10 15.919893264770508\n", - "chr12 17.969276189804077\n", - "chr13 19.71941637992859\n", - "chr14 21.34366464614868\n", - "chr15 22.900768995285034\n", - "chr16 24.27766728401184\n", - "chr17 25.519333600997925\n", - "chr18 26.714667797088623\n", - "chr19 27.614336490631104\n", - "chr20 28.57899522781372\n", - "chr22 29.353068113327026\n", - "chrX 31.731130599975586\n", - "chr2 35.449124813079834\n", - "chr9 37.5920934677124\n", - "chr11 39.65406608581543\n", - "chr1 43.44736051559448\n", - "chr8 45.68234419822693\n", - "chr21 46.41120982170105\n", - "H1-hESC 46.41424226760864\n", - "HCT116 46.41492676734924\n", - "HeLa-S3 46.41563010215759\n", - "HepG2 46.41687893867493\n", - "K562 46.41777992248535\n", - "A549 46.41860294342041\n", - "GM12878 46.41955780982971\n" + "chr3 3.0055365562438965\n", + "chr4 5.905960321426392\n", + "chr5 8.651455879211426\n", + "chr6 11.250766038894653\n", + "chr7 13.660939931869507\n", + "chr10 15.713522672653198\n", + "chr12 17.740623474121094\n", + "chr13 19.478207111358643\n", + "chr14 21.088634252548218\n", + "chr15 22.625713348388672\n", + "chr16 23.987269639968872\n", + "chr17 25.21428894996643\n", + "chr18 26.394341230392456\n", + "chr19 27.28497076034546\n", + "chr20 28.235496282577515\n", + "chr22 28.999913692474365\n", + "chrX 31.338406085968018\n", + "chr2 35.00527381896973\n", + "chr9 37.12277841567993\n", + "chr11 39.157737016677856\n", + "chr1 42.89226841926575\n", + "chr8 45.092690229415894\n", + "chr21 45.81230306625366\n", + "H1-hESC 45.81402635574341\n", + "HCT116 45.814292192459106\n", + "HeLa-S3 45.814526081085205\n", + "HepG2 45.814810276031494\n", + "K562 45.815062522888184\n", + "A549 45.81636619567871\n", + "GM12878 45.81674289703369\n" ] } ], @@ -492,52 +522,8 @@ }, { "cell_type": "code", - "execution_count": 325, - "metadata": { - "jupyter": { - "source_hidden": true - } - }, - "outputs": [], - "source": [ - "def get_random_label_vec(\n", - " metadata_df, seed_chr, seed_celltype, seed_start, output_size=128\n", - "):\n", - " \"\"\"\n", - " Given a coordinate in a celltype, gets the labels of \n", - " the `output_size` 200bp bins from that coordinate onward. \n", - " \"\"\"\n", - " itime = time.time()\n", - " \n", - " # Extract regions from this chromosome in this celltype, to get a window of labels from\n", - " # print(time.time() - itime)\n", - " # chr_msk = np.array(metadata_df['chr']) == seed_region['chr']\n", - " # print(time.time() - itime)\n", - " # ct_msk = np.array(metadata_df['celltype']) == seed_region['celltype']\n", - " # mdf = metadata_df[chr_msk & ct_msk]\n", - " seq_size = output_size*50\n", - " mdf = metadata_df.loc[\n", - " (metadata_df['chr'] == seed_chr) & \n", - " (metadata_df['celltype'] == seed_celltype) & \n", - " (metadata_df['start'] >= seed_start) & \n", - " (metadata_df['stop'] < seed_start+seq_size)\n", - " ]\n", - " print(time.time() - itime)\n", - "\n", - " # Get labels\n", - " y_label_vec = np.zeros(output_size)\n", - " y_label_vec[(mdf['start'] - seed_start) // 50] = mdf['y']\n", - " return mdf, y_label_vec" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "jupyter": { - "source_hidden": true - } - }, + "execution_count": 7, + "metadata": {}, "outputs": [], "source": [ "train_regions_mask = np.isin(_metadata_df['chr'], _train_chroms)\n", @@ -577,6 +563,42 @@ "_metadata_fields = ['chr', 'celltype']" ] }, + { + "cell_type": "code", + "execution_count": 325, + "metadata": {}, + "outputs": [], + "source": [ + "def get_random_label_vec(\n", + " metadata_df, seed_chr, seed_celltype, seed_start, output_size=128\n", + "):\n", + " \"\"\"\n", + " Given a coordinate in a celltype, gets the labels of \n", + " the `output_size` 200bp bins from that coordinate onward. \n", + " \"\"\"\n", + " itime = time.time()\n", + " \n", + " # Extract regions from this chromosome in this celltype, to get a window of labels from\n", + " # print(time.time() - itime)\n", + " # chr_msk = np.array(metadata_df['chr']) == seed_region['chr']\n", + " # print(time.time() - itime)\n", + " # ct_msk = np.array(metadata_df['celltype']) == seed_region['celltype']\n", + " # mdf = metadata_df[chr_msk & ct_msk]\n", + " seq_size = output_size*50\n", + " mdf = metadata_df.loc[\n", + " (metadata_df['chr'] == seed_chr) & \n", + " (metadata_df['celltype'] == seed_celltype) & \n", + " (metadata_df['start'] >= seed_start) & \n", + " (metadata_df['stop'] < seed_start+seq_size)\n", + " ]\n", + " print(time.time() - itime)\n", + "\n", + " # Get labels\n", + " y_label_vec = np.zeros(output_size)\n", + " y_label_vec[(mdf['start'] - seed_start) // 50] = mdf['y']\n", + " return mdf, y_label_vec" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -587,11 +609,7 @@ { "cell_type": "code", "execution_count": 24, - "metadata": { - "jupyter": { - "source_hidden": true - } - }, + "metadata": {}, "outputs": [], "source": [ "import os, time\n", @@ -830,27 +848,6 @@ "# Initialize algorithm" ] }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'algorithm' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mNameError\u001b[0m: name 'algorithm' is not defined" - ] - } - ], - "source": [ - "algorithm.model" - ] - }, { "cell_type": "code", "execution_count": 6, @@ -861,30 +858,38 @@ "output_type": "stream", "text": [ "Train data...\n", - " hospital = 0: n = 53425\n", - " hospital = 1: n = 0\n", - " hospital = 2: n = 0\n", - " hospital = 3: n = 116959\n", - " hospital = 4: n = 132052\n", + " celltype = H1-hESC: n = 5314\n", + " celltype = HCT116: n = 4759\n", + " celltype = HeLa-S3: n = 4635\n", + " celltype = HepG2: n = 4459\n", + " celltype = K562: n = 5169\n", + " celltype = A549: n = 0\n", + " celltype = GM12878: n = 0\n", "Validation (ID) data...\n", - " hospital = 0: n = 6011\n", - " hospital = 1: n = 0\n", - " hospital = 2: n = 0\n", - " hospital = 3: n = 12879\n", - " hospital = 4: n = 14670\n", + " celltype = H1-hESC: n = 6872\n", + " celltype = HCT116: n = 6315\n", + " celltype = HeLa-S3: n = 4219\n", + " celltype = HepG2: n = 8356\n", + " celltype = K562: n = 6538\n", + " celltype = A549: n = 0\n", + " celltype = GM12878: n = 0\n", "Test data...\n", - " hospital = 0: n = 0\n", - " hospital = 1: n = 0\n", - " hospital = 2: n = 85054\n", - " hospital = 3: n = 0\n", - " hospital = 4: n = 0\n", + " celltype = H1-hESC: n = 0\n", + " celltype = HCT116: n = 0\n", + " celltype = HeLa-S3: n = 0\n", + " celltype = HepG2: n = 0\n", + " celltype = K562: n = 0\n", + " celltype = A549: n = 0\n", + " celltype = GM12878: n = 4487\n", "Validation (OOD) data...\n", - " hospital = 0: n = 0\n", - " hospital = 1: n = 34904\n", - " hospital = 2: n = 0\n", - " hospital = 3: n = 0\n", - " hospital = 4: n = 0\n", - "Dout: 2\n" + " celltype = H1-hESC: n = 0\n", + " celltype = HCT116: n = 0\n", + " celltype = HeLa-S3: n = 0\n", + " celltype = HepG2: n = 0\n", + " celltype = K562: n = 0\n", + " celltype = A549: n = 6728\n", + " celltype = GM12878: n = 0\n", + "Dout: 128\n" ] } ], @@ -966,630 +971,237 @@ "cell_type": "code", "execution_count": 7, "metadata": {}, + "outputs": [], + "source": [ + "for batch in datasets['train']['loader']:\n", + " x, y_true, metadata = batch\n", + " break\n", + "# x = torch.transpose(x, 1, 2)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "d = algorithm.process_batch(batch)\n", + "# algorithm.loss.compute" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "DenseNet(\n", - " (features): Sequential(\n", - " (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", - " (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu0): ReLU(inplace=True)\n", - " (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", - " (denseblock1): _DenseBlock(\n", - " (denselayer1): _DenseLayer(\n", - " (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer2): _DenseLayer(\n", - " (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer3): _DenseLayer(\n", - " (norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer4): _DenseLayer(\n", - " (norm1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer5): _DenseLayer(\n", - " (norm1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer6): _DenseLayer(\n", - " (norm1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " )\n", - " (transition1): _Transition(\n", - " (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu): ReLU(inplace=True)\n", - " (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", - " )\n", - " (denseblock2): _DenseBlock(\n", - " (denselayer1): _DenseLayer(\n", - " (norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer2): _DenseLayer(\n", - " (norm1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer3): _DenseLayer(\n", - " (norm1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer4): _DenseLayer(\n", - " (norm1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer5): _DenseLayer(\n", - " (norm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer6): _DenseLayer(\n", - " (norm1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer7): _DenseLayer(\n", - " (norm1): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer8): _DenseLayer(\n", - " (norm1): BatchNorm2d(352, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer9): _DenseLayer(\n", - " (norm1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer10): _DenseLayer(\n", - " (norm1): BatchNorm2d(416, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer11): _DenseLayer(\n", - " (norm1): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(448, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer12): _DenseLayer(\n", - " (norm1): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(480, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " )\n", - " (transition2): _Transition(\n", - " (norm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu): ReLU(inplace=True)\n", - " (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", - " )\n", - " (denseblock3): _DenseBlock(\n", - " (denselayer1): _DenseLayer(\n", - " (norm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer2): _DenseLayer(\n", - " (norm1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer3): _DenseLayer(\n", - " (norm1): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer4): _DenseLayer(\n", - " (norm1): BatchNorm2d(352, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer5): _DenseLayer(\n", - " (norm1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer6): _DenseLayer(\n", - " (norm1): BatchNorm2d(416, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer7): _DenseLayer(\n", - " (norm1): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(448, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer8): _DenseLayer(\n", - " (norm1): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(480, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer9): _DenseLayer(\n", - " (norm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer10): _DenseLayer(\n", - " (norm1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer11): _DenseLayer(\n", - " (norm1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer12): _DenseLayer(\n", - " (norm1): BatchNorm2d(608, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer13): _DenseLayer(\n", - " (norm1): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer14): _DenseLayer(\n", - " (norm1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer15): _DenseLayer(\n", - " (norm1): BatchNorm2d(704, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(704, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer16): _DenseLayer(\n", - " (norm1): BatchNorm2d(736, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(736, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer17): _DenseLayer(\n", - " (norm1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer18): _DenseLayer(\n", - " (norm1): BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(800, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer19): _DenseLayer(\n", - " (norm1): BatchNorm2d(832, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer20): _DenseLayer(\n", - " (norm1): BatchNorm2d(864, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(864, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer21): _DenseLayer(\n", - " (norm1): BatchNorm2d(896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer22): _DenseLayer(\n", - " (norm1): BatchNorm2d(928, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(928, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer23): _DenseLayer(\n", - " (norm1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(960, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer24): _DenseLayer(\n", - " (norm1): BatchNorm2d(992, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(992, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " )\n", - " (transition3): _Transition(\n", - " (norm): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu): ReLU(inplace=True)\n", - " (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", - " )\n", - " (denseblock4): _DenseBlock(\n", - " (denselayer1): _DenseLayer(\n", - " (norm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer2): _DenseLayer(\n", - " (norm1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer3): _DenseLayer(\n", - " (norm1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer4): _DenseLayer(\n", - " (norm1): BatchNorm2d(608, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer5): _DenseLayer(\n", - " (norm1): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer6): _DenseLayer(\n", - " (norm1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer7): _DenseLayer(\n", - " (norm1): BatchNorm2d(704, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(704, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer8): _DenseLayer(\n", - " (norm1): BatchNorm2d(736, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(736, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer9): _DenseLayer(\n", - " (norm1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer10): _DenseLayer(\n", - " (norm1): BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(800, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer11): _DenseLayer(\n", - " (norm1): BatchNorm2d(832, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer12): _DenseLayer(\n", - " (norm1): BatchNorm2d(864, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(864, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer13): _DenseLayer(\n", - " (norm1): BatchNorm2d(896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer14): _DenseLayer(\n", - " (norm1): BatchNorm2d(928, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(928, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer15): _DenseLayer(\n", - " (norm1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(960, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer16): _DenseLayer(\n", - " (norm1): BatchNorm2d(992, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(992, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " )\n", - " (norm5): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " (classifier): Linear(in_features=1024, out_features=2, bias=True)\n", - ")" + "tensor(0.7212, device='cuda:0', grad_fn=)" ] }, - "execution_count": 7, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "algorithm.model" + "a = algorithm.loss.compute(d['y_pred'], d['y_true'], return_dict=False)\n", + "a" ] }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 10, "metadata": {}, "outputs": [ { "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
chrstartstopcelltypesplit
39413chr21000320010009600A5493
39414chr2100032000100038400A5493
39415chr2100102400100108800A5493
39416chr2100172800100179200A5493
39417chr2100230400100236800A5493
..................
495287chr3999680010003200K5620
495288chr39997440099980800K5620
495289chr39998080099987200K5620
495290chr39998720099993600K5620
495291chr399993600100000000K5620
\n", + "

67851 rows × 5 columns

\n", + "
" + ], "text/plain": [ - "" + " chr start stop celltype split\n", + "39413 chr2 10003200 10009600 A549 3\n", + "39414 chr2 100032000 100038400 A549 3\n", + "39415 chr2 100102400 100108800 A549 3\n", + "39416 chr2 100172800 100179200 A549 3\n", + "39417 chr2 100230400 100236800 A549 3\n", + "... ... ... ... ... ...\n", + "495287 chr3 9996800 10003200 K562 0\n", + "495288 chr3 99974400 99980800 K562 0\n", + "495289 chr3 99980800 99987200 K562 0\n", + "495290 chr3 99987200 99993600 K562 0\n", + "495291 chr3 99993600 100000000 K562 0\n", + "\n", + "[67851 rows x 5 columns]" ] }, - "execution_count": 29, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "# algorithm.device\n", - "full_dataset\n", - "# datasets['train']['loader']" + "#np.unique(full_dataset._metadata_df['split'], return_counts=True)\n", + "full_dataset._metadata_df" ] }, { "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "for batch in datasets['train']['loader']:\n", - " x, y_true, metadata = batch\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": 43, + "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,\n", - " 0, 1, 1, 1, 0, 0, 0, 0])" + "(array([0. , 0.5, 1. ], dtype=float32), array([7422683, 1007200, 255045]))" ] }, - "execution_count": 43, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "y_true" + "np.unique(full_dataset.y_array, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": 30, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true - } - }, + "execution_count": 26, + "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[ 0.1406, -0.0628],\n", - " [ 0.0534, 0.0359],\n", - " [-0.0174, -0.0097],\n", - " [-0.0571, -0.2381],\n", - " [ 0.1590, -0.0559],\n", - " [ 0.1254, -0.0139],\n", - " [-0.0423, 0.0439],\n", - " [ 0.1621, 0.0730],\n", - " [ 0.0554, 0.0796],\n", - " [-0.0532, 0.0667],\n", - " [-0.1927, -0.0387],\n", - " [ 0.1352, -0.0385],\n", - " [-0.1320, 0.0140],\n", - " [-0.0531, -0.1171],\n", - " [-0.0378, -0.0134],\n", - " [ 0.1047, 0.0298],\n", - " [ 0.0355, -0.0497],\n", - " [ 0.1065, -0.0218],\n", - " [-0.1883, 0.1298],\n", - " [ 0.0699, -0.0875],\n", - " [-0.1233, 0.1793],\n", - " [ 0.0151, 0.0708],\n", - " [-0.0973, -0.0033],\n", - " [ 0.1027, -0.2456],\n", - " [ 0.0433, -0.0441],\n", - " [ 0.1013, -0.1020],\n", - " [ 0.1309, -0.0051],\n", - " [ 0.0028, -0.0558],\n", - " [ 0.0635, 0.0575],\n", - " [-0.0066, 0.0666],\n", - " [-0.0076, -0.0375],\n", - " [ 0.1336, 0.0024]], device='cuda:0', grad_fn=)" + "0.8546625832706961" ] }, - "execution_count": 30, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "# datasets['train']['dataset'].size()\n", - "algorithm.model(x.to(algorithm.device))" + "7422683/8684928" ] }, { @@ -1601,9 +1213,40 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Epoch [0]:\n", + "\n", + "Train:\n", + "torch.Size([8192]) torch.Size([8192]) torch.Size([64, 128]) torch.Size([64, 128])\n", + "torch.Size([]) torch.Size([8192]) torch.Size([64, 128]) torch.Size([64, 128])\n" + ] + }, + { + "ename": "AssertionError", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m train(\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdatasets\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/examples/train.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(algorithm, datasets, general_logger, config, epoch_offset, best_val_metric)\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0;31m# First run training\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 90\u001b[0;31m \u001b[0mrun_epoch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgeneral_logger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 91\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;31m# Then run val\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/examples/train.py\u001b[0m in \u001b[0;36mrun_epoch\u001b[0;34m(algorithm, dataset, general_logger, epoch, config, train)\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miterator\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 43\u001b[0;31m \u001b[0mbatch_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 44\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0mbatch_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/examples/algorithms/single_model_algorithm.py\u001b[0m in \u001b[0;36mupdate\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 103\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_update\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[0;31m# log results\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 105\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_log\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 106\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msanitize_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/examples/algorithms/group_algorithm.py\u001b[0m in \u001b[0;36mupdate_log\u001b[0;34m(self, results)\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mm\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogged_metrics\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_group_logging\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m group_metrics, group_counts, worst_group_metric = m.compute_group_wise(\n\u001b[0m\u001b[1;32m 50\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'y_pred'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'y_true'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/wilds/common/metrics/metric.py\u001b[0m in \u001b[0;36mcompute_group_wise\u001b[0;34m(self, y_pred, y_true, g, n_groups, return_dict)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mDictionary\u001b[0m \u001b[0mof\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \"\"\"\n\u001b[0;32m--> 114\u001b[0;31m \u001b[0mgroup_metrics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgroup_counts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mworst_group_metric\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compute_group_wise\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_groups\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 115\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/wilds/common/metrics/metric.py\u001b[0m in \u001b[0;36m_compute_group_wise\u001b[0;34m(self, y_pred, y_true, g, n_groups)\u001b[0m\n\u001b[1;32m 234\u001b[0m \u001b[0mflattened_g\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mindices\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 235\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mflattened_metrics\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mflattened_g\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 236\u001b[0;31m \u001b[0mgroup_metrics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgroup_counts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mavg_over_groups\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mflattened_metrics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mflattened_g\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_groups\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 237\u001b[0m \u001b[0mworst_group_metric\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mworst\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgroup_metrics\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mgroup_counts\u001b[0m\u001b[0;34m>\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 238\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mgroup_metrics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgroup_counts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mworst_group_metric\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/wilds/common/utils.py\u001b[0m in \u001b[0;36mavg_over_groups\u001b[0;34m(v, g, n_groups)\u001b[0m\n\u001b[1;32m 84\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[0mdevice\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 86\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 87\u001b[0m \u001b[0mgroup_count\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_counts\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_groups\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[0mgroup_avgs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch_scatter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscatter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msrc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindex\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mn_groups\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'mean'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mAssertionError\u001b[0m: " + ] + } + ], "source": [ "if not config.eval_only:\n", " ## Load saved results if resuming\n", @@ -1681,867 +1324,6 @@ "outputs": [], "source": [] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 33, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "\n", - "class Beagle(nn.Module):\n", - " \"\"\"\n", - " Neural net models over genomic sequence.\n", - " Input:\n", - " - sequence_length: int (default 1000) \n", - " - Shape: (N, 5, sequence_length, 1) with batch size N.\n", - " \n", - " Output:\n", - " - prediction (Tensor): float torch tensor of shape (N, )\n", - " \n", - " TODO: Finish docstring.\n", - " \"\"\"\n", - " def __init__(self):\n", - " \"\"\"\n", - " Parameters\n", - " ----------\n", - " sequence_length : int\n", - " n_genomic_features : int\n", - " \"\"\"\n", - " super(Beagle, self).__init__()\n", - "\n", - " self.dropout = 0.3\n", - " self.num_cell_types = 1\n", - " self.conv1 = nn.Conv2d(5, 50, (19, 1), stride = (1, 1), padding=(9,0))\n", - " self.conv2 = nn.Conv2d(50, 50, (11, 1), stride = (1, 1), padding = (5,0))\n", - " self.conv3 = nn.Conv2d(50, 50, (7, 1), stride = (1, 1), padding = (4,0))\n", - " self.bn1 = nn.BatchNorm2d(50)\n", - " self.bn2 = nn.BatchNorm2d(50)\n", - " self.bn3 = nn.BatchNorm2d(50)\n", - " self.maxpool1 = nn.MaxPool2d((3, 1))\n", - " self.maxpool2 = nn.MaxPool2d((4, 1))\n", - " self.maxpool3 = nn.MaxPool2d((4, 1))\n", - "\n", - " self.fc1 = nn.Linear(4200, 1000)\n", - " self.bn4 = nn.BatchNorm1d(1000)\n", - "\n", - " self.fc2 = nn.Linear(1000, 1000)\n", - " self.bn5 = nn.BatchNorm1d(1000)\n", - "\n", - " self.fc3 = nn.Linear(1000, self.num_cell_types)\n", - "\n", - " def forward(self, s):\n", - " s = s.permute(0, 2, 1).contiguous() # batch_size x 5 x 1000\n", - " s = s.view(-1, 5, 1000, 1) # batch_size x 5 x 1000 x 1 [5 channels]\n", - " s = self.maxpool1(F.relu(self.bn1(self.conv1(s)))) # batch_size x 300 x 333 x 1\n", - " s = self.maxpool2(F.relu(self.bn2(self.conv2(s)))) # batch_size x 200 x 83 x 1\n", - " s = self.maxpool3(F.relu(self.bn3(self.conv3(s)))) # batch_size x 200 x 21 x 1\n", - " s = s.view(-1, 4200)\n", - " conv_out = s\n", - "\n", - " s = F.dropout(F.relu(self.bn4(self.fc1(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", - " s = F.dropout(F.relu(self.bn5(self.fc2(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", - " \n", - " s = self.fc3(s)\n", - "\n", - " return s, conv_out" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [], - "source": [ - "import math\n", - "import numpy as np\n", - "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "\n", - "\n", - "def double_conv(in_channels, out_channels): \n", - " return nn.Sequential(\n", - " nn.Conv1d(in_channels, out_channels, 7, padding=2), \n", - " nn.BatchNorm1d(out_channels), \n", - " nn.ReLU(inplace=True),\n", - " nn.Conv1d(out_channels, out_channels, 7, padding=3), \n", - " nn.BatchNorm1d(out_channels), \n", - " nn.ReLU(inplace=True)\n", - " )\n", - "\n", - "\n", - "class UNet(nn.Module):\n", - "\n", - " def __init__(self, n_class):\n", - " super().__init__()\n", - " \n", - " self.dconv_down1 = double_conv(6, 15)\n", - " self.dconv_down2 = double_conv(15, 22)\n", - " self.dconv_down3 = double_conv(22, 33)\n", - " self.dconv_down4 = double_conv(33, 49)\n", - " self.dconv_down5 = double_conv(49, 73)\n", - " self.dconv_down6 = double_conv(73, 109)\n", - "\n", - " self.maxpool = nn.MaxPool1d(2)\n", - " self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) \n", - " \n", - " self.dconv_up5 = double_conv(73 + 109, 73)\n", - " self.dconv_up4 = double_conv(49 + 73, 49)\n", - " self.dconv_up3 = double_conv(33 + 49, 33)\n", - " self.dconv_up2 = double_conv(22 + 33, 22)\n", - " self.dconv_up1 = double_conv(15 + 22, 15)\n", - " \n", - " self.conv_last = nn.Conv2d(15, n_class, 1)\n", - " \n", - " \n", - " def forward(self, x):\n", - " conv1 = self.dconv_down1(x)\n", - " x = self.maxpool(conv1)\n", - "\n", - " conv2 = self.dconv_down2(x)\n", - " x = self.maxpool(conv2)\n", - " \n", - " conv3 = self.dconv_down3(x)\n", - " x = self.maxpool(conv3)\n", - " \n", - " conv4 = self.dconv_down4(x)\n", - " x = self.maxpool(conv4)\n", - " \n", - " conv5 = self.dconv_down5(x)\n", - " x = self.maxpool(conv5)\n", - " \n", - " x = self.dconv_down6(x)\n", - " \n", - " x = self.upsample(x) \n", - " x = torch.cat([x, conv5], dim=1)\n", - " \n", - " x = self.dconv_up5(x)\n", - " x = self.upsample(x) \n", - " x = torch.cat([x, conv4], dim=1)\n", - " \n", - " x = self.dconv_up4(x)\n", - " x = self.upsample(x) \n", - " x = torch.cat([x, conv3], dim=1)\n", - " \n", - " x = self.dconv_up3(x)\n", - " x = self.upsample(x) \n", - " x = torch.cat([x, conv2], dim=1) \n", - "\n", - " x = self.dconv_up2(x)\n", - " x = self.upsample(x) \n", - " x = torch.cat([x, conv1], dim=1) \n", - " \n", - " x = self.dconv_up1(x)\n", - " \n", - " out = self.conv_last(x)\n", - " \n", - " return out" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "UNet(\n", - " (dconv_down1): Sequential(\n", - " (0): Conv1d(6, 15, kernel_size=(7,), stride=(1,), padding=(2,))\n", - " (1): BatchNorm1d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU(inplace=True)\n", - " (3): Conv1d(15, 15, kernel_size=(7,), stride=(1,), padding=(3,))\n", - " (4): BatchNorm1d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): ReLU(inplace=True)\n", - " )\n", - " (dconv_down2): Sequential(\n", - " (0): Conv1d(15, 22, kernel_size=(7,), stride=(1,), padding=(2,))\n", - " (1): BatchNorm1d(22, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU(inplace=True)\n", - " (3): Conv1d(22, 22, kernel_size=(7,), stride=(1,), padding=(3,))\n", - " (4): BatchNorm1d(22, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): ReLU(inplace=True)\n", - " )\n", - " (dconv_down3): Sequential(\n", - " (0): Conv1d(22, 33, kernel_size=(7,), stride=(1,), padding=(2,))\n", - " (1): BatchNorm1d(33, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU(inplace=True)\n", - " (3): Conv1d(33, 33, kernel_size=(7,), stride=(1,), padding=(3,))\n", - " (4): BatchNorm1d(33, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): ReLU(inplace=True)\n", - " )\n", - " (dconv_down4): Sequential(\n", - " (0): Conv1d(33, 49, kernel_size=(7,), stride=(1,), padding=(2,))\n", - " (1): BatchNorm1d(49, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU(inplace=True)\n", - " (3): Conv1d(49, 49, kernel_size=(7,), stride=(1,), padding=(3,))\n", - " (4): BatchNorm1d(49, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): ReLU(inplace=True)\n", - " )\n", - " (dconv_down5): Sequential(\n", - " (0): Conv1d(49, 73, kernel_size=(7,), stride=(1,), padding=(2,))\n", - " (1): BatchNorm1d(73, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU(inplace=True)\n", - " (3): Conv1d(73, 73, kernel_size=(7,), stride=(1,), padding=(3,))\n", - " (4): BatchNorm1d(73, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): ReLU(inplace=True)\n", - " )\n", - " (dconv_down6): Sequential(\n", - " (0): Conv1d(73, 109, kernel_size=(7,), stride=(1,), padding=(2,))\n", - " (1): BatchNorm1d(109, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU(inplace=True)\n", - " (3): Conv1d(109, 109, kernel_size=(7,), stride=(1,), padding=(3,))\n", - " (4): BatchNorm1d(109, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): ReLU(inplace=True)\n", - " )\n", - " (maxpool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", - " (upsample): Upsample(scale_factor=2.0, mode=bilinear)\n", - " (dconv_up5): Sequential(\n", - " (0): Conv1d(182, 73, kernel_size=(7,), stride=(1,), padding=(2,))\n", - " (1): BatchNorm1d(73, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU(inplace=True)\n", - " (3): Conv1d(73, 73, kernel_size=(7,), stride=(1,), padding=(3,))\n", - " (4): BatchNorm1d(73, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): ReLU(inplace=True)\n", - " )\n", - " (dconv_up4): Sequential(\n", - " (0): Conv1d(122, 49, kernel_size=(7,), stride=(1,), padding=(2,))\n", - " (1): BatchNorm1d(49, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU(inplace=True)\n", - " (3): Conv1d(49, 49, kernel_size=(7,), stride=(1,), padding=(3,))\n", - " (4): BatchNorm1d(49, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): ReLU(inplace=True)\n", - " )\n", - " (dconv_up3): Sequential(\n", - " (0): Conv1d(82, 33, kernel_size=(7,), stride=(1,), padding=(2,))\n", - " (1): BatchNorm1d(33, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU(inplace=True)\n", - " (3): Conv1d(33, 33, kernel_size=(7,), stride=(1,), padding=(3,))\n", - " (4): BatchNorm1d(33, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): ReLU(inplace=True)\n", - " )\n", - " (dconv_up2): Sequential(\n", - " (0): Conv1d(55, 22, kernel_size=(7,), stride=(1,), padding=(2,))\n", - " (1): BatchNorm1d(22, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU(inplace=True)\n", - " (3): Conv1d(22, 22, kernel_size=(7,), stride=(1,), padding=(3,))\n", - " (4): BatchNorm1d(22, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): ReLU(inplace=True)\n", - " )\n", - " (dconv_up1): Sequential(\n", - " (0): Conv1d(37, 15, kernel_size=(7,), stride=(1,), padding=(2,))\n", - " (1): BatchNorm1d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU(inplace=True)\n", - " (3): Conv1d(15, 15, kernel_size=(7,), stride=(1,), padding=(3,))\n", - " (4): BatchNorm1d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): ReLU(inplace=True)\n", - " )\n", - " (conv_last): Conv2d(15, 2, kernel_size=(1, 1), stride=(1, 1))\n", - ")" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model = UNet(2)\n", - "model" - ] - }, - { - "cell_type": "code", - "execution_count": 101, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "485773" - ] - }, - "execution_count": 101, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def count_parameters(model):\n", - " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", - "\n", - "lst = [(x[0], x[1].numel()) for x in model.named_parameters()]\n", - "#np.sum([x[1] for x in lst])\n", - "count_parameters(model)" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "6955906" - ] - }, - "execution_count": 25, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "count_parameters(algorithm.model)" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [], - "source": [ - "lst = [(x[0], x[1].numel()) for x in algorithm.model.named_parameters()]" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "DenseNet(\n", - " (features): Sequential(\n", - " (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", - " (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu0): ReLU(inplace=True)\n", - " (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", - " (denseblock1): _DenseBlock(\n", - " (denselayer1): _DenseLayer(\n", - " (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer2): _DenseLayer(\n", - " (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer3): _DenseLayer(\n", - " (norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer4): _DenseLayer(\n", - " (norm1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer5): _DenseLayer(\n", - " (norm1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer6): _DenseLayer(\n", - " (norm1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " )\n", - " (transition1): _Transition(\n", - " (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu): ReLU(inplace=True)\n", - " (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", - " )\n", - " (denseblock2): _DenseBlock(\n", - " (denselayer1): _DenseLayer(\n", - " (norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer2): _DenseLayer(\n", - " (norm1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer3): _DenseLayer(\n", - " (norm1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer4): _DenseLayer(\n", - " (norm1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer5): _DenseLayer(\n", - " (norm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer6): _DenseLayer(\n", - " (norm1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer7): _DenseLayer(\n", - " (norm1): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer8): _DenseLayer(\n", - " (norm1): BatchNorm2d(352, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer9): _DenseLayer(\n", - " (norm1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer10): _DenseLayer(\n", - " (norm1): BatchNorm2d(416, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer11): _DenseLayer(\n", - " (norm1): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(448, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer12): _DenseLayer(\n", - " (norm1): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(480, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " )\n", - " (transition2): _Transition(\n", - " (norm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu): ReLU(inplace=True)\n", - " (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", - " )\n", - " (denseblock3): _DenseBlock(\n", - " (denselayer1): _DenseLayer(\n", - " (norm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer2): _DenseLayer(\n", - " (norm1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer3): _DenseLayer(\n", - " (norm1): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer4): _DenseLayer(\n", - " (norm1): BatchNorm2d(352, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer5): _DenseLayer(\n", - " (norm1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer6): _DenseLayer(\n", - " (norm1): BatchNorm2d(416, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer7): _DenseLayer(\n", - " (norm1): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(448, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer8): _DenseLayer(\n", - " (norm1): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(480, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer9): _DenseLayer(\n", - " (norm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer10): _DenseLayer(\n", - " (norm1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer11): _DenseLayer(\n", - " (norm1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer12): _DenseLayer(\n", - " (norm1): BatchNorm2d(608, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer13): _DenseLayer(\n", - " (norm1): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer14): _DenseLayer(\n", - " (norm1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer15): _DenseLayer(\n", - " (norm1): BatchNorm2d(704, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(704, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer16): _DenseLayer(\n", - " (norm1): BatchNorm2d(736, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(736, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer17): _DenseLayer(\n", - " (norm1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer18): _DenseLayer(\n", - " (norm1): BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(800, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer19): _DenseLayer(\n", - " (norm1): BatchNorm2d(832, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer20): _DenseLayer(\n", - " (norm1): BatchNorm2d(864, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(864, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer21): _DenseLayer(\n", - " (norm1): BatchNorm2d(896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer22): _DenseLayer(\n", - " (norm1): BatchNorm2d(928, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(928, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer23): _DenseLayer(\n", - " (norm1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(960, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer24): _DenseLayer(\n", - " (norm1): BatchNorm2d(992, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(992, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " )\n", - " (transition3): _Transition(\n", - " (norm): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu): ReLU(inplace=True)\n", - " (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", - " )\n", - " (denseblock4): _DenseBlock(\n", - " (denselayer1): _DenseLayer(\n", - " (norm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer2): _DenseLayer(\n", - " (norm1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer3): _DenseLayer(\n", - " (norm1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer4): _DenseLayer(\n", - " (norm1): BatchNorm2d(608, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer5): _DenseLayer(\n", - " (norm1): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer6): _DenseLayer(\n", - " (norm1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer7): _DenseLayer(\n", - " (norm1): BatchNorm2d(704, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(704, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer8): _DenseLayer(\n", - " (norm1): BatchNorm2d(736, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(736, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer9): _DenseLayer(\n", - " (norm1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer10): _DenseLayer(\n", - " (norm1): BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(800, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer11): _DenseLayer(\n", - " (norm1): BatchNorm2d(832, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer12): _DenseLayer(\n", - " (norm1): BatchNorm2d(864, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(864, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer13): _DenseLayer(\n", - " (norm1): BatchNorm2d(896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer14): _DenseLayer(\n", - " (norm1): BatchNorm2d(928, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(928, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer15): _DenseLayer(\n", - " (norm1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(960, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer16): _DenseLayer(\n", - " (norm1): BatchNorm2d(992, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(992, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " )\n", - " (norm5): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " (classifier): Linear(in_features=1024, out_features=2, bias=True)\n", - ")" - ] - }, - "execution_count": 39, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "algorithm.model" - ] - }, { "cell_type": "code", "execution_count": null, diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py index 3330243e..0d84f1ad 100644 --- a/wilds/common/metrics/all_metrics.py +++ b/wilds/common/metrics/all_metrics.py @@ -61,6 +61,30 @@ def _compute_flattened(self, flattened_y_pred, flattened_y_true): def worst(self, metrics): return minimum(metrics) +class MultiTaskAveragePrecision(MultiTaskMetric): + def __init__(self, prediction_fn=logits_to_binary_pred, name=None, average='macro'): + self.prediction_fn = prediction_fn + if name is None: + name = f'avgprec' + if average is not None: + name+=f'-{average}' + self.average = average + super().__init__(name=name) + + def _compute_flattened(self, flattened_y_pred, flattened_y_true): + if self.prediction_fn is not None: + flattened_y_pred = self.prediction_fn(flattened_y_pred) + score = sklearn.metrics.average_precision_score( + np.array(flattened_y_true.squeeze().detach().cpu().numpy() > 0), + flattened_y_pred.squeeze().detach().cpu().numpy(), + average=self.average + ) + return torch.tensor(score).to(flattened_y_pred.device) + + def worst(self, metrics): + return minimum(metrics) + + class Recall(Metric): def __init__(self, prediction_fn=logits_to_pred, name=None, average='binary'): self.prediction_fn = prediction_fn @@ -93,7 +117,11 @@ def __init__(self, prediction_fn=logits_to_pred, name=None, average='macro'): def _compute(self, y_pred, y_true): if self.prediction_fn is not None: y_pred = self.prediction_fn(y_pred) - score = sklearn.metrics.average_precision_score(y_true, y_pred, average=self.average, labels=torch.unique(y_true)) + score = sklearn.metrics.average_precision_score( + np.array(y_true.squeeze().detach().cpu().numpy() > 0), + y_pred.squeeze().detach().cpu().numpy(), + average=self.average + ) return torch.tensor(score) def worst(self, metrics): diff --git a/wilds/common/metrics/metric.py b/wilds/common/metrics/metric.py index 4c3e8440..281696d8 100644 --- a/wilds/common/metrics/metric.py +++ b/wilds/common/metrics/metric.py @@ -232,6 +232,7 @@ def _compute(self, y_pred, y_true): def _compute_group_wise(self, y_pred, y_true, g, n_groups): flattened_metrics, indices = self.compute_flattened(y_pred, y_true, return_dict=False) flattened_g = g[indices] + print(flattened_metrics.shape, flattened_g.shape, y_pred.shape, y_true.shape) group_metrics, group_counts = avg_over_groups(flattened_metrics, flattened_g, n_groups) worst_group_metric = self.worst(group_metrics[group_counts>0]) return group_metrics, group_counts, worst_group_metric diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 23b70014..588b9fce 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -15,7 +15,7 @@ class EncodeTFBSDataset(WILDSDataset): This is a subset of the dataset from the ENCODE-DREAM in vivo Transcription Factor Binding Site Prediction Challenge. Input (x): - 1000-base-pair regions of sequence with a quantified chromatin accessibility readout. + 12800-base-pair regions of sequence with a quantified chromatin accessibility readout. Label (y): y is binary. It is 1 if the central 200bp region is bound by the transcription factor MAX, and 0 otherwise. @@ -36,9 +36,9 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): self._y_size = 128 # self._n_classes = 2 - self._train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] - self._val_chroms = ['chr2', 'chr9', 'chr11'] - self._test_chroms = ['chr1', 'chr8', 'chr21'] + self._train_chroms = ['chr3']#, 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] + self._val_chroms = ['chr2']#, 'chr9', 'chr11'] + self._test_chroms = ['chr1']#, 'chr8', 'chr21'] self._transcription_factor = 'MAX' self._train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] self._val_celltype = ['A549'] @@ -165,15 +165,18 @@ def get_input(self, idx, window_size=12800): (4) Window_size, the length of sequence returned (centered on the 6400bp region in (3)) """ this_metadata = self._metadata_df.iloc[idx, :] + chrom = this_metadata['chr'] interval_start = this_metadata['start'] - int(window_size/4) interval_end = interval_start + window_size #this_metadata['stop'] seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end] dnase_bw = self._dnase_allcelltypes[this_metadata['celltype']] dnase_this = dnase_bw.values(chrom, interval_start, interval_end, numpy=True) + # print("{}:{}-{}".format(chrom, interval_start, interval_end)) dnase_avg = self._dnase_allcelltypes['avg'].values(chrom, interval_start, interval_end, numpy=True) return torch.tensor(np.column_stack( - [np.nan_to_num(seq_this), np.nan_to_num(dnase_this), np.nan_to_num(dnase_avg)] - )) + [np.nan_to_num(seq_this), + np.nan_to_num(dnase_this), np.nan_to_num(dnase_avg)] + ).T) def eval(self, y_pred, y_true, metadata): return self.standard_group_eval( From ee708a6eb1f07606e778ee47a8aa2ec5bcc0a7ec Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Fri, 12 Mar 2021 09:37:06 -0800 Subject: [PATCH 030/244] initial commit from gwhd fork --- wilds/__init__.py | 1 + wilds/datasets/gwhd_dataset.py | 130 +++++++++++++++++++++++++++++++++ wilds/get_dataset.py | 6 +- 3 files changed, 136 insertions(+), 1 deletion(-) create mode 100644 wilds/datasets/gwhd_dataset.py diff --git a/wilds/__init__.py b/wilds/__init__.py index 77f0ad5a..fe1bed0a 100644 --- a/wilds/__init__.py +++ b/wilds/__init__.py @@ -10,6 +10,7 @@ 'poverty', 'fmow', 'py150', + 'gwhd', ] additional_datasets = [ diff --git a/wilds/datasets/gwhd_dataset.py b/wilds/datasets/gwhd_dataset.py new file mode 100644 index 00000000..064be8ff --- /dev/null +++ b/wilds/datasets/gwhd_dataset.py @@ -0,0 +1,130 @@ +import numpy as np +import pandas as pd +import torch +from pathlib import Path +from PIL import Image +#from wilds.common.metrics.all_metrics import MultiTaskAccuracy +from wilds.datasets.wilds_dataset import WILDSDataset + + +class GWHDDataset(WILDSDataset): + """ + The GWHD-wilds wheat head localization dataset. + This is a modified version of the original Global Wheat Head Dataset. + This dataset is not part of the official WILDS benchmark. + We provide it for convenience and to reproduce observations discussed in the WILDS paper. + Supported `split_scheme`: + 'official' for WILDS related tasks. + To reproduce the baseline, several splits are needed: + - to train a model on train domains and test against a all test split: 'train_in-dist' + - to train a model on a portion of a specific val or test domain and test it against the remaining portion: + "{domain}_in-dist" where domain is the id of a domain (usask_1, uq_1, utokyo_1, utokyo_2, nau_1) + no validation datasets are accessible for the baseline splits + Input (x): + 1024x1024 RGB images of wheat field canopy between flowering and ripening. + Output (y): + y is a nx4-dimensional vector where each line represents a box coordinate (top-x,top-y,height,width) + Metadata: + Each image is annotated with the ID of the domain it came from (integer from 0 to 10). + Website: + http://www.global-wheat.com/ + Original publication: + @article{david_global_2020, + title = {Global {Wheat} {Head} {Detection} ({GWHD}) {Dataset}: {A} {Large} and {Diverse} {Dataset} of {High}-{Resolution} {RGB}-{Labelled} {Images} to {Develop} and {Benchmark} {Wheat} {Head} {Detection} {Methods}}, + volume = {2020}, + url = {https://doi.org/10.34133/2020/3521852}, + doi = {10.34133/2020/3521852}, + journal = {Plant Phenomics}, + author = {David, Etienne and Madec, Simon and Sadeghi-Tehran, Pouria and Aasen, Helge and Zheng, Bangyou and Liu, Shouyang and Kirchgessner, Norbert and Ishikawa, Goro and Nagasawa, Koichi and Badhon, Minhajul A. and Pozniak, Curtis and de Solan, Benoit and Hund, Andreas and Chapman, Scott C. and Baret, Frédéric and Stavness, Ian and Guo, Wei}, + month = aug, + year = {2020}, + note = {Publisher: AAAS}, + pages = {3521852}, + } + License: + This dataset is distributed under the MIT license. + https://github.com/snap-stanford/ogb/blob/master/LICENSE + """ + + _dataset_name = 'gwhd' + _versions_dict = { + '1.0': { + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x42fa9775eacc453489a428abd59a437d/contents/blob/', + 'compressed_size': None}} + + def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): + + self._version = version + self._data_dir = self.initialize_data_dir(root_dir, download) + self._original_resolution = (1024, 1024) + self.root = Path(self.data_dir) + + self._split_scheme = split_scheme + + # Get filenames + + if split_scheme =="official": + train_data_df = pd.read_csv(self.root / f'{split_scheme}_train.csv') + val_data_df = pd.read_csv(self.root / f'{split_scheme}_val.csv') + test_data_df = pd.read_csv(self.root / f'{split_scheme}_test.csv') + + elif split_scheme == "train_in-dist": + train_data_df = pd.read_csv(self.root / f'official_train.csv') + test_data_df = pd.read_csv(self.root / f'{split_scheme}_test.csv') + val_data_df = pd.DataFrame(columns=["image","labels","group"]) + elif split_scheme in [f"{domain}_in-dist" for domain in ["nau_1", "utokyo_1", "utokyo_2", "usask_1" , "uq_1"]]: + train_data_df = pd.read_csv(self.root / f'{split_scheme}_train.csv') + test_data_df = pd.read_csv(self.root / f'{split_scheme}_test.csv') + val_data_df = pd.DataFrame(columns=["image","labels","group"]) + + elif split_scheme == "in-dist": + train_data_df = pd.read_csv(self.root / f'{split_scheme}_train.csv') + test_data_df = pd.read_csv(self.root / f'{split_scheme}_test.csv') + val_data_df = pd.DataFrame(columns=["image","labels","group"]) + + + self._image_array = [] + self._split_array, self._y_array, self._metadata_array = [], [], [] + + for i, df in enumerate([train_data_df, val_data_df, test_data_df]): + self._image_array.extend(list(df['image'].values)) + labels = list(df['labels'].values) + self._split_array.extend([i] * len(labels)) + + + + labels = [{"boxes": torch.stack([ torch.tensor([int(i) for i in box.split(" ")]) for box in boxes.split(";")]) ,"labels": torch.tensor([1.]*len(list(boxes.split(";")))).long() } if type(boxes) != float else {"boxes":torch.empty(0,4),"labels":torch.empty(0,1,dtype=torch.long)} for boxes in labels] + + self._y_array.extend(labels) + + + self._metadata_array.extend(list(df['group'].values)) + + + self._y_size = 1 + + self._metadata_fields = ["domain"] + + self._split_array = np.array(self._split_array) + + + + + self._metadata_array = torch.tensor(self._metadata_array, + dtype=torch.long).unsqueeze(1) + + #self._metric = MultiTaskAccuracy() + + def get_input(self, idx): + """ + Returns x for a given idx. + """ + img_filename = self.root / "images" / self._image_array[idx] + x = Image.open(img_filename) + return x + + def eval(self, y_pred, y_true, metadata): + return self.standard_group_eval( + self._metric, + self._eval_grouper, + y_pred, y_true, metadata) diff --git a/wilds/get_dataset.py b/wilds/get_dataset.py index 1073100f..cfa5f2c7 100644 --- a/wilds/get_dataset.py +++ b/wilds/get_dataset.py @@ -55,7 +55,7 @@ def get_dataset(dataset, version=None, **dataset_kwargs): elif dataset == 'poverty': if version == '1.0': from wilds.datasets.archive.poverty_v1_0_dataset import PovertyMapDataset - else: + else: from wilds.datasets.poverty_dataset import PovertyMapDataset return PovertyMapDataset(version=version, **dataset_kwargs) @@ -77,3 +77,7 @@ def get_dataset(dataset, version=None, **dataset_kwargs): elif dataset == 'sqf': from wilds.datasets.sqf_dataset import SQFDataset return SQFDataset(version=version, **dataset_kwargs) + + elif dataset == 'gwhd': + from wilds.datasets.gwhd_dataset import GWHDDataset + return GWHDDataset(version=version, **dataset_kwargs) From 119f3a401dc73163e50a045879e1f5acb5b4f39a Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Fri, 12 Mar 2021 23:20:43 -0800 Subject: [PATCH 031/244] detr support up til algorithm.update() --- examples/algorithms/group_algorithm.py | 4 +- examples/algorithms/initializer.py | 29 +- examples/algorithms/single_model_algorithm.py | 7 +- examples/configs/datasets.py | 23 + examples/configs/model.py | 35 +- examples/configs/supported.py | 30 +- examples/losses.py | 47 ++ examples/models/detr/README.md | 1 + examples/models/detr/__init__.py | 6 + examples/models/detr/backbone.py | 119 +++++ examples/models/detr/detr.py | 338 +++++++++++++ examples/models/detr/matcher.py | 86 ++++ examples/models/detr/position_encoding.py | 89 ++++ examples/models/detr/transformer.py | 296 +++++++++++ examples/models/detr/util/__init__.py | 1 + examples/models/detr/util/box_ops.py | 88 ++++ examples/models/detr/util/misc.py | 467 ++++++++++++++++++ examples/models/detr/util/plot_utils.py | 107 ++++ examples/models/initializer.py | 72 ++- examples/run_expt.py | 2 + examples/utils.py | 11 + wilds/common/metrics/all_metrics.py | 15 +- wilds/common/metrics/metric.py | 4 +- wilds/common/utils.py | 9 +- wilds/datasets/gwhd_dataset.py | 60 ++- wilds/datasets/wilds_dataset.py | 12 +- 26 files changed, 1895 insertions(+), 63 deletions(-) create mode 100644 examples/losses.py create mode 100644 examples/models/detr/README.md create mode 100644 examples/models/detr/__init__.py create mode 100644 examples/models/detr/backbone.py create mode 100644 examples/models/detr/detr.py create mode 100644 examples/models/detr/matcher.py create mode 100644 examples/models/detr/position_encoding.py create mode 100644 examples/models/detr/transformer.py create mode 100644 examples/models/detr/util/__init__.py create mode 100644 examples/models/detr/util/box_ops.py create mode 100644 examples/models/detr/util/misc.py create mode 100644 examples/models/detr/util/plot_utils.py diff --git a/examples/algorithms/group_algorithm.py b/examples/algorithms/group_algorithm.py index 54cac1a8..eb0b95c2 100644 --- a/examples/algorithms/group_algorithm.py +++ b/examples/algorithms/group_algorithm.py @@ -3,7 +3,7 @@ from algorithms.algorithm import Algorithm from utils import update_average from scheduler import step_scheduler -from wilds.common.utils import get_counts +from wilds.common.utils import get_counts, numel class GroupAlgorithm(Algorithm): """ @@ -57,7 +57,7 @@ def update_log(self, results): results['y_pred'], results['y_true'], return_dict=False).item() - count = results['y_true'].numel() + count = numel(results['y_true']) # transfer other statistics in the results dictionary for field in self.logged_fields: diff --git a/examples/algorithms/initializer.py b/examples/algorithms/initializer.py index 00748cfc..180e9ff5 100644 --- a/examples/algorithms/initializer.py +++ b/examples/algorithms/initializer.py @@ -3,7 +3,8 @@ from algorithms.groupDRO import GroupDRO from algorithms.deepCORAL import DeepCORAL from algorithms.IRM import IRM -from configs.supported import algo_log_metrics, losses +from configs.supported import algo_log_metrics +from losses import initialize_loss def initialize_algorithm(config, datasets, train_grouper): train_dataset = datasets['train']['dataset'] @@ -11,23 +12,27 @@ def initialize_algorithm(config, datasets, train_grouper): # Configure the final layer of the networks used # The code below are defaults. Edit this if you need special config for your model. - if (train_dataset.is_classification) and (train_dataset.y_size == 1): - # For single-task classification, we have one output per class + if train_dataset.is_classification: + if train_dataset.y_size == 1: + # For single-task classification, we have one output per class + d_out = train_dataset.n_classes + elif train_dataset.y_size is None: + d_out = train_dataset.n_classes + elif (train_dataset.y_size > 1) and (train_dataset.n_classes == 2): + # For multi-task binary classification (each output is the logit for each binary class) + d_out = train_dataset.y_size + else: + raise RuntimeError('d_out not defined.') + elif train_dataset.is_detection: + # For detection, d_out is the number of classes d_out = train_dataset.n_classes - elif (train_dataset.is_classification) and (train_dataset.y_size is None): - d_out = train_dataset.n_classes - elif (train_dataset.is_classification) and (train_dataset.y_size > 1) and (train_dataset.n_classes == 2): - # For multi-task binary classification (each output is the logit for each binary class) - d_out = train_dataset.y_size - elif (not train_dataset.is_classification): + else: # For regression, we have one output per target dimension d_out = train_dataset.y_size - else: - raise RuntimeError('d_out not defined.') # Other config n_train_steps = len(train_loader) * config.n_epochs - loss = losses[config.loss_function] + loss = initialize_loss(config, d_out) metric = algo_log_metrics[config.algo_log_metric] if config.algorithm=='ERM': diff --git a/examples/algorithms/single_model_algorithm.py b/examples/algorithms/single_model_algorithm.py index e368b88f..1ee1ad5d 100644 --- a/examples/algorithms/single_model_algorithm.py +++ b/examples/algorithms/single_model_algorithm.py @@ -3,6 +3,7 @@ from scheduler import initialize_scheduler from optimizer import initialize_optimizer from torch.nn.utils import clip_grad_norm_ +from utils import move_to class SingleModelAlgorithm(GroupAlgorithm): """ @@ -47,9 +48,9 @@ def process_batch(self, batch): - y_true """ x, y_true, metadata = batch - x = x.to(self.device) - y_true = y_true.to(self.device) - g = self.grouper.metadata_to_group(metadata).to(self.device) + x = move_to(x, self.device) + y_true = move_to(y_true, self.device) + g = move_to(self.grouper.metadata_to_group(metadata), self.device) outputs = self.model(x) results = { diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index cd2d1d6f..f93c38d1 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -282,6 +282,29 @@ 'n_epochs': 4, 'process_outputs_function': None, }, + 'gwhd': { + 'split_scheme': 'official', + 'model': 'detr', + 'train_transform': 'image_base', + 'eval_transform': 'image_base', + 'model_kwargs': { + 'aux_loss': True, + 'n_queries': 150, + 'n_classes': 1}, + 'loss_function': 'detr_set_criterion', + 'groupby_fields': ['location'], + 'val_metric': 'dummy_all', # TODO + 'val_metric_decreasing': False, + 'algo_log_metric': None, # TODO + 'optimizer': 'Adam', + 'optimizer_kwargs': {}, + 'scheduler': None, + 'batch_size': 4, + 'lr': 1e-4, + 'weight_decay': 1e-4, + 'n_epochs': 10, + 'process_outputs_function': None, + } } ########################################## diff --git a/examples/configs/model.py b/examples/configs/model.py index 46714bbe..d2cb1cd4 100644 --- a/examples/configs/model.py +++ b/examples/configs/model.py @@ -15,19 +15,19 @@ 'scheduler': 'linear_schedule_with_warmup', }, 'densenet121': { - 'model_kwargs':{ + 'model_kwargs': { 'pretrained':True, }, 'target_resolution': (224, 224), }, 'wideresnet50': { - 'model_kwargs':{ + 'model_kwargs': { 'pretrained':True, }, 'target_resolution': (224, 224), }, 'resnet50': { - 'model_kwargs':{ + 'model_kwargs': { 'pretrained':True, }, 'target_resolution': (224, 224), @@ -37,4 +37,33 @@ 'target_resolution': (224, 224), }, 'logistic_regression': {}, + 'detr': { + 'max_grad_norm': 0.1, + 'model_kwargs': { + # Backbone. Always uses sine position embedding. + 'train_backbone': True, + 'backbone': 'resnet50', + 'dilation': False, + # Transformer + 'enc_layers': 6, + 'dec_layers': 6, + 'dim_feedforward': 2048, + 'hidden_dim': 256, + 'dropout': 0.1, + 'nheads': 8, + 'pre_norm': False, + }, + 'loss_kwargs': { + # Matcher + 'set_cost_class': 1, + 'set_cost_bbox': 5, + 'set_cost_giou': 2, + # Loss + 'mask_loss_coef': 1, + 'dice_loss_coef': 1, + 'bbox_loss_coef': 5, + 'giou_loss_coef': 2, + 'eos_coef': 0.1, + } + } } diff --git a/examples/configs/supported.py b/examples/configs/supported.py index 8b66b74e..a1d30fdb 100644 --- a/examples/configs/supported.py +++ b/examples/configs/supported.py @@ -1,18 +1,6 @@ -import torch.nn as nn -import torch -import sys, os - # metrics -from wilds.common.metrics.loss import ElementwiseLoss, Loss, MultiTaskLoss from wilds.common.metrics.all_metrics import Accuracy, MultiTaskAccuracy, MSE, multiclass_logits_to_pred, binary_logits_to_pred -losses = { - 'cross_entropy': ElementwiseLoss(loss_fn=nn.CrossEntropyLoss(reduction='none')), - 'lm_cross_entropy': MultiTaskLoss(loss_fn=nn.CrossEntropyLoss(reduction='none')), - 'mse': MSE(name='loss'), - 'multitask_bce': MultiTaskLoss(loss_fn=nn.BCEWithLogitsLoss(reduction='none')), -} - algo_log_metrics = { 'accuracy': Accuracy(prediction_fn=multiclass_logits_to_pred), 'mse': MSE(), @@ -27,11 +15,23 @@ None: None, } -# see initialize_*() functions for correspondence -transforms = ['bert', 'image_base', 'image_resize_and_center_crop', 'poverty_train'] +# See models/initializer.py models = ['resnet18_ms', 'resnet50', 'resnet34', 'wideresnet50', 'densenet121', 'bert-base-uncased', 'distilbert-base-uncased', - 'gin-virtual', 'logistic_regression', 'code-gpt-py'] + 'gin-virtual', 'logistic_regression', 'code-gpt-py', + 'detr'] + +# See algorithms/initializer.py algorithms = ['ERM', 'groupDRO', 'deepCORAL', 'IRM'] + +# See optimizer.py optimizers = ['SGD', 'Adam', 'AdamW'] + +# See scheduler.py schedulers = ['linear_schedule_with_warmup', 'ReduceLROnPlateau', 'StepLR'] + +# See transforms.py +transforms = ['bert', 'image_base', 'image_resize_and_center_crop', 'poverty_train'] + +# See losses.py +losses = ['cross_entropy', 'lm_cross_entropy', 'MSE', 'multitask_bce', 'detr_set_criterion'] diff --git a/examples/losses.py b/examples/losses.py new file mode 100644 index 00000000..9d8df5b7 --- /dev/null +++ b/examples/losses.py @@ -0,0 +1,47 @@ +import torch.nn as nn +from wilds.common.metrics.loss import ElementwiseLoss, Loss, MultiTaskLoss + +def initialize_loss(config, d_out): + if config.loss_function == 'cross_entropy': + return ElementwiseLoss(loss_fn=nn.CrossEntropyLoss(reduction='none')) + + elif config.loss_function == 'lm_cross_entropy': + return MultiTaskLoss(loss_fn=nn.CrossEntropyLoss(reduction='none')) + + elif config.loss_function == 'MSE': + return MSE(name='loss') + + elif config.loss_function == 'multitask_bce': + return MultiTaskLoss(loss_fn=nn.BCEWithLogitsLoss(reduction='none')) + + elif config.loss_function == 'detr_set_criterion': + return ElementwiseLoss(loss_fn=get_detr_set_criterion(config, d_out)) + + +def get_detr_set_criterion(config, d_out): + from examples.models.detr.matcher import HungarianMatcher + from examples.models.detr.detr import SetCriterion + + matcher = HungarianMatcher( + cost_class=config.loss_kwargs['set_cost_class'], + cost_bbox=config.loss_kwargs['set_cost_bbox'], + cost_giou=config.loss_kwargs['set_cost_giou']) + weight_dict = { + 'loss_ce': 1, + 'loss_bbox': config.loss_kwargs['bbox_loss_coef']} + weight_dict['loss_giou'] = config.loss_kwargs['giou_loss_coef'] + + if config.model_kwargs['aux_loss']: + aux_weight_dict = {} + for i in range(config.model_kwargs['dec_layers'] - 1): + aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()}) + weight_dict.update(aux_weight_dict) + + criterion = SetCriterion( + d_out, + matcher=matcher, + weight_dict=weight_dict, + eos_coef=config.loss_kwargs['eos_coef'], + losses=['labels', 'boxes', 'cardinality']).to(config.device) + + return criterion diff --git a/examples/models/detr/README.md b/examples/models/detr/README.md new file mode 100644 index 00000000..0be3336f --- /dev/null +++ b/examples/models/detr/README.md @@ -0,0 +1 @@ +DETR is licensed under the [Apache License 2.0](https://github.com/facebookresearch/detr/blob/master/LICENSE). Code is adapted from the [DETR GitHub repository](https://github.com/facebookresearch/detr/). diff --git a/examples/models/detr/__init__.py b/examples/models/detr/__init__.py new file mode 100644 index 00000000..a3f26531 --- /dev/null +++ b/examples/models/detr/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from .detr import build + + +def build_model(args): + return build(args) diff --git a/examples/models/detr/backbone.py b/examples/models/detr/backbone.py new file mode 100644 index 00000000..d03e8a5d --- /dev/null +++ b/examples/models/detr/backbone.py @@ -0,0 +1,119 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Backbone modules. +""" +from collections import OrderedDict + +import torch +import torch.nn.functional as F +import torchvision +from torch import nn +from torchvision.models._utils import IntermediateLayerGetter +from typing import Dict, List + +from .util.misc import NestedTensor, is_main_process + +from .position_encoding import build_position_encoding + + +class FrozenBatchNorm2d(torch.nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Copy-paste from torchvision.misc.ops with added eps before rqsrt, + without which any other models than torchvision.models.resnet[18,34,50,101] + produce nans. + """ + + def __init__(self, n): + super(FrozenBatchNorm2d, self).__init__() + self.register_buffer("weight", torch.ones(n)) + self.register_buffer("bias", torch.zeros(n)) + self.register_buffer("running_mean", torch.zeros(n)) + self.register_buffer("running_var", torch.ones(n)) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + num_batches_tracked_key = prefix + 'num_batches_tracked' + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super(FrozenBatchNorm2d, self)._load_from_state_dict( + state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + + def forward(self, x): + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + eps = 1e-5 + scale = w * (rv + eps).rsqrt() + bias = b - rm * scale + return x * scale + bias + + +class BackboneBase(nn.Module): + + def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): + super().__init__() + for name, parameter in backbone.named_parameters(): + if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: + parameter.requires_grad_(False) + if return_interm_layers: + return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} + else: + return_layers = {'layer4': "0"} + self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) + self.num_channels = num_channels + + def forward(self, tensor_list: NestedTensor): + xs = self.body(tensor_list.tensors) + out: Dict[str, NestedTensor] = {} + for name, x in xs.items(): + m = tensor_list.mask + assert m is not None + mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] + out[name] = NestedTensor(x, mask) + return out + + +class Backbone(BackboneBase): + """ResNet backbone with frozen BatchNorm.""" + def __init__(self, name: str, + train_backbone: bool, + return_interm_layers: bool, + dilation: bool): + backbone = getattr(torchvision.models, name)( + replace_stride_with_dilation=[False, False, dilation], + pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) + num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 + super().__init__(backbone, train_backbone, num_channels, return_interm_layers) + + +class Joiner(nn.Sequential): + def __init__(self, backbone, position_embedding): + super().__init__(backbone, position_embedding) + + def forward(self, tensor_list: NestedTensor): + xs = self[0](tensor_list) + out: List[NestedTensor] = [] + pos = [] + for name, x in xs.items(): + out.append(x) + # position encoding + pos.append(self[1](x).to(x.tensors.dtype)) + + return out, pos + + +def build_backbone(args): + position_embedding = build_position_encoding(args) + train_backbone = args.lr_backbone > 0 + return_interm_layers = args.masks + backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) + model = Joiner(backbone, position_embedding) + model.num_channels = backbone.num_channels + return model diff --git a/examples/models/detr/detr.py b/examples/models/detr/detr.py new file mode 100644 index 00000000..ef9e4e6a --- /dev/null +++ b/examples/models/detr/detr.py @@ -0,0 +1,338 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +DETR model and criterion classes. +""" +import torch +import torch.nn.functional as F +from torch import nn + +from .util import box_ops +from .util.misc import (NestedTensor, nested_tensor_from_tensor_list, + accuracy, get_world_size, interpolate, + is_dist_avail_and_initialized) + +from .backbone import build_backbone +from .matcher import build_matcher +from .transformer import build_transformer + + +class DETR(nn.Module): + """ This is the DETR module that performs object detection """ + def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False): + """ Initializes the model. + Parameters: + backbone: torch module of the backbone to be used. See backbone.py + transformer: torch module of the transformer architecture. See transformer.py + num_classes: number of object classes + num_queries: number of object queries, ie detection slot. This is the maximal number of objects + DETR can detect in a single image. For COCO, we recommend 100 queries. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. + """ + super().__init__() + self.num_queries = num_queries + self.transformer = transformer + hidden_dim = transformer.d_model + self.class_embed = nn.Linear(hidden_dim, num_classes + 1) + self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) + self.query_embed = nn.Embedding(num_queries, hidden_dim) + self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1) + self.backbone = backbone + self.aux_loss = aux_loss + + def forward(self, samples: NestedTensor): + """ The forward expects a NestedTensor, which consists of: + - samples.tensor: batched images, of shape [batch_size x 3 x H x W] + - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels + + It returns a dict with the following elements: + - "pred_logits": the classification logits (including no-object) for all queries. + Shape= [batch_size x num_queries x (num_classes + 1)] + - "pred_boxes": The normalized boxes coordinates for all queries, represented as + (center_x, center_y, height, width). These values are normalized in [0, 1], + relative to the size of each individual image (disregarding possible padding). + See PostProcess for information on how to retrieve the unnormalized bounding box. + - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of + dictionnaries containing the two above keys for each decoder layer. + """ + if isinstance(samples, (list, torch.Tensor)): + samples = nested_tensor_from_tensor_list(samples) + features, pos = self.backbone(samples) + + src, mask = features[-1].decompose() + assert mask is not None + hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0] + + outputs_class = self.class_embed(hs) + outputs_coord = self.bbox_embed(hs).sigmoid() + out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]} + if self.aux_loss: + out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord) + return out + + @torch.jit.unused + def _set_aux_loss(self, outputs_class, outputs_coord): + # this is a workaround to make torchscript happy, as torchscript + # doesn't support dictionary with non-homogeneous values, such + # as a dict having both a Tensor and a list. + return [{'pred_logits': a, 'pred_boxes': b} + for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] + + +class SetCriterion(nn.Module): + """ This class computes the loss for DETR. + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses): + """ Create the criterion. + Parameters: + num_classes: number of object categories, omitting the special no-object category + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. + eos_coef: relative classification weight applied to the no-object category + losses: list of all the losses to be applied. See get_loss for list of available losses. + """ + super().__init__() + self.num_classes = num_classes + self.matcher = matcher + self.weight_dict = weight_dict + self.eos_coef = eos_coef + self.losses = losses + empty_weight = torch.ones(self.num_classes + 1) + empty_weight[-1] = self.eos_coef + self.register_buffer('empty_weight', empty_weight) + + def loss_labels(self, outputs, targets, indices, num_boxes, log=True): + """Classification loss (NLL) + targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] + """ + assert 'pred_logits' in outputs + src_logits = outputs['pred_logits'] + + idx = self._get_src_permutation_idx(indices) + target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) + target_classes = torch.full(src_logits.shape[:2], self.num_classes, + dtype=torch.int64, device=src_logits.device) + target_classes[idx] = target_classes_o + + # loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight) + loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight, reduction='none') + losses = {'loss_ce': loss_ce.mean(dim=1)} + + if log: + # TODO this should probably be a separate loss, not hacked in this one here + losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0] + return losses + + @torch.no_grad() + def loss_cardinality(self, outputs, targets, indices, num_boxes): + """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes + This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients + """ + pred_logits = outputs['pred_logits'] + device = pred_logits.device + tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device) + # Count the number of predictions that are NOT "no-object" (which is the last class) + card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1) + card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) + losses = {'cardinality_error': card_err} + return losses + + def loss_boxes(self, outputs, targets, indices, num_boxes): + """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss + targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] + The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. + """ + assert 'pred_boxes' in outputs + idx = self._get_src_permutation_idx(indices) + src_boxes = outputs['pred_boxes'][idx] + target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) + + + loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none').sum(dim=1) + + loss_giou = 1 - torch.diag(box_ops.generalized_box_iou( + box_ops.box_cxcywh_to_xyxy(src_boxes), + box_ops.box_cxcywh_to_xyxy(target_boxes))) + + tgt_lengths = [len(v["labels"]) for v in targets] + + device = outputs['pred_logits'].device + losses = {} + losses['loss_bbox'] = torch.zeros(len(tgt_lengths), device=device) + losses['loss_giou'] = torch.zeros(len(tgt_lengths), device=device) + + pos = 0 + for i, tgt_length in enumerate(tgt_lengths): + losses['loss_bbox'][i] = loss_bbox[pos:pos+tgt_length].mean() + losses['loss_giou'][i] = loss_giou[pos:pos+tgt_length].mean() + pos += tgt_length + + # losses['loss_bbox'] = loss_bbox.sum() / num_boxes + # losses['loss_giou'] = loss_giou.sum() / num_boxes + return losses + + def _get_src_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx + + def _get_tgt_permutation_idx(self, indices): + # permute targets following indices + batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): + loss_map = { + 'labels': self.loss_labels, + 'cardinality': self.loss_cardinality, + 'boxes': self.loss_boxes, + } + assert loss in loss_map, f'do you really want to compute {loss} loss?' + return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) + + def forward(self, outputs, targets): + """ This performs the loss computation. + Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. + The expected keys in each dict depends on the losses applied, see each loss' doc + """ + outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'} + + # Retrieve the matching between the outputs of the last layer and the targets + indices = self.matcher(outputs_without_aux, targets) + + # Compute the average number of target boxes accross all nodes, for normalization purposes + num_boxes = sum(len(t["labels"]) for t in targets) + num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) + if is_dist_avail_and_initialized(): + torch.distributed.all_reduce(num_boxes) + num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() + + # Compute all the requested losses + total_loss = 0 + losses = {} + for loss in self.losses: + losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes)) + + # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. + if 'aux_outputs' in outputs: + for i, aux_outputs in enumerate(outputs['aux_outputs']): + indices = self.matcher(aux_outputs, targets) + for loss in self.losses: + kwargs = {} + if loss == 'labels': + # Logging is enabled only for the last layer + kwargs = {'log': False} + l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs) + l_dict = {k + f'_{i}': v for k, v in l_dict.items()} + losses.update(l_dict) + + # Sum up weighted losses by element + + device = outputs['pred_logits'].device + elementwise_loss = torch.zeros(len(outputs['pred_logits']), device=device) + + for k in self.weight_dict: + elementwise_loss += self.weight_dict[k] * losses[k] + + return elementwise_loss + + +class PostProcess(nn.Module): + """ This module converts the model's output into the format expected by the coco api""" + @torch.no_grad() + def forward(self, outputs, target_sizes): + """ Perform the computation + Parameters: + outputs: raw outputs of the model + target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch + For evaluation, this must be the original image size (before any data augmentation) + For visualization, this should be the image size after data augment, but before padding + """ + out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes'] + + assert len(out_logits) == len(target_sizes) + assert target_sizes.shape[1] == 2 + + prob = F.softmax(out_logits, -1) + scores, labels = prob[..., :-1].max(-1) + + # convert to [x0, y0, x1, y1] format + boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) + # and from relative [0, 1] to absolute [0, height] coordinates + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) + boxes = boxes * scale_fct[:, None, :] + + results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)] + + return results + + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +def build(args): + # the `num_classes` naming here is somewhat misleading. + # it indeed corresponds to `max_obj_id + 1`, where max_obj_id + # is the maximum id for a class in your dataset. For example, + # COCO has a max_obj_id of 90, so we pass `num_classes` to be 91. + # As another example, for a dataset that has a single class with id 1, + # you should pass `num_classes` to be 2 (max_obj_id + 1). + # For more details on this, check the following discussion + # https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223 + num_classes = 20 if args.dataset_file != 'coco' else 91 + if args.dataset_file == "coco_panoptic": + # for panoptic, we just add a num_classes that is large enough to hold + # max_obj_id + 1, but the exact value doesn't really matter + num_classes = 250 + device = torch.device(args.device) + + backbone = build_backbone(args) + + transformer = build_transformer(args) + + model = DETR( + backbone, + transformer, + num_classes=num_classes, + num_queries=args.num_queries, + aux_loss=args.aux_loss, + ) + + matcher = build_matcher(args) + weight_dict = {'loss_ce': 1, 'loss_bbox': args.bbox_loss_coef} + weight_dict['loss_giou'] = args.giou_loss_coef + + # TODO this is a hack + if args.aux_loss: + aux_weight_dict = {} + for i in range(args.dec_layers - 1): + aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()}) + weight_dict.update(aux_weight_dict) + + losses = ['labels', 'boxes', 'cardinality'] + + criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict, + eos_coef=args.eos_coef, losses=losses) + criterion.to(device) + postprocessors = {'bbox': PostProcess()} + + return model, criterion, postprocessors diff --git a/examples/models/detr/matcher.py b/examples/models/detr/matcher.py new file mode 100644 index 00000000..48f1177a --- /dev/null +++ b/examples/models/detr/matcher.py @@ -0,0 +1,86 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Modules to compute the matching cost and solve the corresponding LSAP. +""" +import torch +from scipy.optimize import linear_sum_assignment +from torch import nn + +from .util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou + + +class HungarianMatcher(nn.Module): + """This class computes an assignment between the targets and the predictions of the network + + For efficiency reasons, the targets don't include the no_object. Because of this, in general, + there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, + while the others are un-matched (and thus treated as non-objects). + """ + + def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1): + """Creates the matcher + + Params: + cost_class: This is the relative weight of the classification error in the matching cost + cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost + cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost + """ + super().__init__() + self.cost_class = cost_class + self.cost_bbox = cost_bbox + self.cost_giou = cost_giou + assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0" + + @torch.no_grad() + def forward(self, outputs, targets): + """ Performs the matching + + Params: + outputs: This is a dict that contains at least these entries: + "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits + "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates + + targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: + "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth + objects in the target) containing the class labels + "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates + + Returns: + A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_boxes) + """ + bs, num_queries = outputs["pred_logits"].shape[:2] + + # We flatten to compute the cost matrices in a batch + out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes] + out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] + + # Also concat the target labels and boxes + tgt_ids = torch.cat([v["labels"] for v in targets]) + tgt_bbox = torch.cat([v["boxes"] for v in targets]) + + # Compute the classification cost. Contrary to the loss, we don't use the NLL, + # but approximate it in 1 - proba[target class]. + # The 1 is a constant that doesn't change the matching, it can be ommitted. + cost_class = -out_prob[:, tgt_ids] + + # Compute the L1 cost between boxes + cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) + + # Compute the giou cost betwen boxes + cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)) + + # Final cost matrix + C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou + C = C.view(bs, num_queries, -1).cpu() + + sizes = [len(v["boxes"]) for v in targets] + indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] + return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] + + +def build_matcher(args): + return HungarianMatcher(cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou) diff --git a/examples/models/detr/position_encoding.py b/examples/models/detr/position_encoding.py new file mode 100644 index 00000000..bc7d9eb4 --- /dev/null +++ b/examples/models/detr/position_encoding.py @@ -0,0 +1,89 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Various positional encodings for the transformer. +""" +import math +import torch +from torch import nn + +from .util.misc import NestedTensor + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + mask = tensor_list.mask + assert mask is not None + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. + """ + def __init__(self, num_pos_feats=256): + super().__init__() + self.row_embed = nn.Embedding(50, num_pos_feats) + self.col_embed = nn.Embedding(50, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + h, w = x.shape[-2:] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + pos = torch.cat([ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) + return pos + + +def build_position_encoding(args): + N_steps = args.hidden_dim // 2 + if args.position_embedding in ('v2', 'sine'): + # TODO find a better way of exposing other arguments + position_embedding = PositionEmbeddingSine(N_steps, normalize=True) + elif args.position_embedding in ('v3', 'learned'): + position_embedding = PositionEmbeddingLearned(N_steps) + else: + raise ValueError(f"not supported {args.position_embedding}") + + return position_embedding diff --git a/examples/models/detr/transformer.py b/examples/models/detr/transformer.py new file mode 100644 index 00000000..714c84df --- /dev/null +++ b/examples/models/detr/transformer.py @@ -0,0 +1,296 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +DETR Transformer class. + +Copy-paste from torch.nn.Transformer with modifications: + * positional encodings are passed in MHattention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers +""" +import copy +from typing import Optional, List + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + + +class Transformer(nn.Module): + + def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, + num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False, + return_intermediate_dec=False): + super().__init__() + + encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, + return_intermediate=return_intermediate_dec) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, mask, query_embed, pos_embed): + # flatten NxCxHxW to HWxNxC + bs, c, h, w = src.shape + src = src.flatten(2).permute(2, 0, 1) + pos_embed = pos_embed.flatten(2).permute(2, 0, 1) + query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) + mask = mask.flatten(1) + + tgt = torch.zeros_like(query_embed) + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) + hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, + pos=pos_embed, query_pos=query_embed) + return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) + + +class TransformerEncoder(nn.Module): + + def __init__(self, encoder_layer, num_layers, norm=None): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward(self, src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + output = src + + for layer in self.layers: + output = layer(output, src_mask=mask, + src_key_padding_mask=src_key_padding_mask, pos=pos) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class TransformerDecoder(nn.Module): + + def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + output = tgt + + intermediate = [] + + for layer in self.layers: + output = layer(output, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, query_pos=query_pos) + if self.return_intermediate: + intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output.unsqueeze(0) + + +class TransformerEncoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class TransformerDecoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + def forward_pre(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + return self.forward_post(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def build_transformer(args): + return Transformer( + d_model=args.hidden_dim, + dropout=args.dropout, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + normalize_before=args.pre_norm, + return_intermediate_dec=True, + ) + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") diff --git a/examples/models/detr/util/__init__.py b/examples/models/detr/util/__init__.py new file mode 100644 index 00000000..168f9979 --- /dev/null +++ b/examples/models/detr/util/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved diff --git a/examples/models/detr/util/box_ops.py b/examples/models/detr/util/box_ops.py new file mode 100644 index 00000000..9c088e5b --- /dev/null +++ b/examples/models/detr/util/box_ops.py @@ -0,0 +1,88 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Utilities for bounding box manipulation and GIoU. +""" +import torch +from torchvision.ops.boxes import box_area + + +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), + (x_c + 0.5 * w), (y_c + 0.5 * h)] + return torch.stack(b, dim=-1) + + +def box_xyxy_to_cxcywh(x): + x0, y0, x1, y1 = x.unbind(-1) + b = [(x0 + x1) / 2, (y0 + y1) / 2, + (x1 - x0), (y1 - y0)] + return torch.stack(b, dim=-1) + + +# modified from torchvision to also return the union +def box_iou(boxes1, boxes2): + area1 = box_area(boxes1) + area2 = box_area(boxes2) + + lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] + rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] + + wh = (rb - lt).clamp(min=0) # [N,M,2] + inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] + + union = area1[:, None] + area2 - inter + + iou = inter / union + return iou, union + + +def generalized_box_iou(boxes1, boxes2): + """ + Generalized IoU from https://giou.stanford.edu/ + + The boxes should be in [x0, y0, x1, y1] format + + Returns a [N, M] pairwise matrix, where N = len(boxes1) + and M = len(boxes2) + """ + # degenerate boxes gives inf / nan results + # so do an early check + assert (boxes1[:, 2:] >= boxes1[:, :2]).all() + assert (boxes2[:, 2:] >= boxes2[:, :2]).all() + iou, union = box_iou(boxes1, boxes2) + + lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) + rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) + + wh = (rb - lt).clamp(min=0) # [N,M,2] + area = wh[:, :, 0] * wh[:, :, 1] + + return iou - (area - union) / area + + +def masks_to_boxes(masks): + """Compute the bounding boxes around the provided masks + + The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. + + Returns a [N, 4] tensors, with the boxes in xyxy format + """ + if masks.numel() == 0: + return torch.zeros((0, 4), device=masks.device) + + h, w = masks.shape[-2:] + + y = torch.arange(0, h, dtype=torch.float) + x = torch.arange(0, w, dtype=torch.float) + y, x = torch.meshgrid(y, x) + + x_mask = (masks * x.unsqueeze(0)) + x_max = x_mask.flatten(1).max(-1)[0] + x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + y_mask = (masks * y.unsqueeze(0)) + y_max = y_mask.flatten(1).max(-1)[0] + y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] + + return torch.stack([x_min, y_min, x_max, y_max], 1) diff --git a/examples/models/detr/util/misc.py b/examples/models/detr/util/misc.py new file mode 100644 index 00000000..1d4e5eb1 --- /dev/null +++ b/examples/models/detr/util/misc.py @@ -0,0 +1,467 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" +import os +import subprocess +import time +from collections import defaultdict, deque +import datetime +import pickle +from typing import Optional, List + +import torch +import torch.distributed as dist +from torch import Tensor + +# needed due to empty tensor bug in pytorch and torchvision 0.5 +import torchvision +if float(torchvision.__version__[:3]) < 0.7: + from torchvision.ops import _new_empty_tensor + from torchvision.ops.misc import _output_size + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device="cuda") + size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) + if local_size != max_size: + padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + if torch.cuda.is_available(): + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}', + 'max mem: {memory:.0f}' + ]) + else: + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ]) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() + sha = 'N/A' + diff = "clean" + branch = 'N/A' + try: + sha = _run(['git', 'rev-parse', 'HEAD']) + subprocess.check_output(['git', 'diff'], cwd=cwd) + diff = _run(['git', 'diff-index', 'HEAD']) + diff = "has uncommited changes" if diff else "clean" + branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +def collate_fn(batch): + batch = list(zip(*batch)) + batch[0] = nested_tensor_from_tensor_list(batch[0]) + return tuple(batch) + + +def _max_by_axis(the_list): + # type: (List[List[int]]) -> List[int] + maxes = the_list[0] + for sublist in the_list[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + +class NestedTensor(object): + def __init__(self, tensors, mask: Optional[Tensor]): + self.tensors = tensors + self.mask = mask + + def to(self, device): + # type: (Device) -> NestedTensor # noqa + cast_tensor = self.tensors.to(device) + mask = self.mask + if mask is not None: + assert mask is not None + cast_mask = mask.to(device) + else: + cast_mask = None + return NestedTensor(cast_tensor, cast_mask) + + def decompose(self): + return self.tensors, self.mask + + def __repr__(self): + return str(self.tensors) + + +def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): + # TODO make this more general + if tensor_list[0].ndim == 3: + if torchvision._is_tracing(): + # nested_tensor_from_tensor_list() does not export well to ONNX + # call _onnx_nested_tensor_from_tensor_list() instead + return _onnx_nested_tensor_from_tensor_list(tensor_list) + + # TODO make it support different-sized images + max_size = _max_by_axis([list(img.shape) for img in tensor_list]) + # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) + batch_shape = [len(tensor_list)] + max_size + b, c, h, w = batch_shape + dtype = tensor_list[0].dtype + device = tensor_list[0].device + tensor = torch.zeros(batch_shape, dtype=dtype, device=device) + mask = torch.ones((b, h, w), dtype=torch.bool, device=device) + for img, pad_img, m in zip(tensor_list, tensor, mask): + pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + m[: img.shape[1], :img.shape[2]] = False + else: + raise ValueError('not supported') + return NestedTensor(tensor, mask) + + +# _onnx_nested_tensor_from_tensor_list() is an implementation of +# nested_tensor_from_tensor_list() that is supported by ONNX tracing. +@torch.jit.unused +def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: + max_size = [] + for i in range(tensor_list[0].dim()): + max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64) + max_size.append(max_size_i) + max_size = tuple(max_size) + + # work around for + # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) + # m[: img.shape[1], :img.shape[2]] = False + # which is not yet supported in onnx + padded_imgs = [] + padded_masks = [] + for img in tensor_list: + padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] + padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) + padded_imgs.append(padded_img) + + m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) + padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) + padded_masks.append(padded_mask.to(torch.bool)) + + tensor = torch.stack(padded_imgs) + mask = torch.stack(padded_masks) + + return NestedTensor(tensor, mask=mask) + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}'.format( + args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) + + +@torch.no_grad() +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + if target.numel() == 0: + return [torch.zeros([], device=output.device)] + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): + # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor + """ + Equivalent to nn.functional.interpolate, but with support for empty batch sizes. + This will eventually be supported natively by PyTorch, and this + class can go away. + """ + if float(torchvision.__version__[:3]) < 0.7: + if input.numel() > 0: + return torch.nn.functional.interpolate( + input, size, scale_factor, mode, align_corners + ) + + output_shape = _output_size(2, input, size, scale_factor) + output_shape = list(input.shape[:-2]) + list(output_shape) + return _new_empty_tensor(input, output_shape) + else: + return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/examples/models/detr/util/plot_utils.py b/examples/models/detr/util/plot_utils.py new file mode 100644 index 00000000..0f24bed0 --- /dev/null +++ b/examples/models/detr/util/plot_utils.py @@ -0,0 +1,107 @@ +""" +Plotting utilities to visualize training logs. +""" +import torch +import pandas as pd +import numpy as np +import seaborn as sns +import matplotlib.pyplot as plt + +from pathlib import Path, PurePath + + +def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'): + ''' + Function to plot specific fields from training log(s). Plots both training and test results. + + :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file + - fields = which results to plot from each log file - plots both training and test for each field. + - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots + - log_name = optional, name of log file if different than default 'log.txt'. + + :: Outputs - matplotlib plots of results in fields, color coded for each log file. + - solid lines are training results, dashed lines are test results. + + ''' + func_name = "plot_utils.py::plot_logs" + + # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path, + # convert single Path to list to avoid 'not iterable' error + + if not isinstance(logs, list): + if isinstance(logs, PurePath): + logs = [logs] + print(f"{func_name} info: logs param expects a list argument, converted to list[Path].") + else: + raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \ + Expect list[Path] or single Path obj, received {type(logs)}") + + # Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir + for i, dir in enumerate(logs): + if not isinstance(dir, PurePath): + raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}") + if not dir.exists(): + raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}") + # verify log_name exists + fn = Path(dir / log_name) + if not fn.exists(): + print(f"-> missing {log_name}. Have you gotten to Epoch 1 in training?") + print(f"--> full path of missing log file: {fn}") + return + + # load log file(s) and plot + dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs] + + fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5)) + + for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))): + for j, field in enumerate(fields): + if field == 'mAP': + coco_eval = pd.DataFrame( + np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1] + ).ewm(com=ewm_col).mean() + axs[j].plot(coco_eval, c=color) + else: + df.interpolate().ewm(com=ewm_col).mean().plot( + y=[f'train_{field}', f'test_{field}'], + ax=axs[j], + color=[color] * 2, + style=['-', '--'] + ) + for ax, field in zip(axs, fields): + ax.legend([Path(p).name for p in logs]) + ax.set_title(field) + + +def plot_precision_recall(files, naming_scheme='iter'): + if naming_scheme == 'exp_id': + # name becomes exp_id + names = [f.parts[-3] for f in files] + elif naming_scheme == 'iter': + names = [f.stem for f in files] + else: + raise ValueError(f'not supported {naming_scheme}') + fig, axs = plt.subplots(ncols=2, figsize=(16, 5)) + for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names): + data = torch.load(f) + # precision is n_iou, n_points, n_cat, n_area, max_det + precision = data['precision'] + recall = data['params'].recThrs + scores = data['scores'] + # take precision for all classes, all areas and 100 detections + precision = precision[0, :, :, 0, -1].mean(1) + scores = scores[0, :, :, 0, -1].mean(1) + prec = precision.mean() + rec = data['recall'][0, :, 0, -1].mean() + print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' + + f'score={scores.mean():0.3f}, ' + + f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}' + ) + axs[0].plot(recall, precision, c=color) + axs[1].plot(recall, scores, c=color) + + axs[0].set_title('Precision / Recall') + axs[0].legend(names) + axs[1].set_title('Scores / Recall') + axs[1].legend(names) + return fig, axs diff --git a/examples/models/initializer.py b/examples/models/initializer.py index 4d414763..86f0b83d 100644 --- a/examples/models/initializer.py +++ b/examples/models/initializer.py @@ -1,12 +1,6 @@ import torch.nn as nn -import torchvision -from models.bert.bert import BertClassifier, BertFeaturizer -from models.bert.distilbert import DistilBertClassifier, DistilBertFeaturizer -from models.resnet_multispectral import ResNet18 + from models.layers import Identity -from models.gnn import GINVirtual -from models.code_gpt import GPT2LMHeadLogit, GPT2FeaturizerLMHeadLogit -from transformers import GPT2Tokenizer def initialize_model(config, d_out, is_featurizer=False): """ @@ -36,6 +30,7 @@ def initialize_model(config, d_out, is_featurizer=False): name=config.model, d_out=d_out, **config.model_kwargs) + elif 'bert' in config.model: if is_featurizer: featurizer = initialize_bert_based_model(config, d_out, is_featurizer) @@ -43,21 +38,28 @@ def initialize_model(config, d_out, is_featurizer=False): model = (featurizer, classifier) else: model = initialize_bert_based_model(config, d_out) + elif config.model == 'resnet18_ms': # multispectral resnet 18 + from models.resnet_multispectral import ResNet18 if is_featurizer: featurizer = ResNet18(num_classes=None, **config.model_kwargs) classifier = nn.Linear(featurizer.d_out, d_out) model = (featurizer, classifier) else: model = ResNet18(num_classes=d_out, **config.model_kwargs) + elif config.model == 'gin-virtual': + from models.gnn import GINVirtual if is_featurizer: featurizer = GINVirtual(num_tasks=None, **config.model_kwargs) classifier = nn.Linear(featurizer.d_out, d_out) model = (featurizer, classifier) else: model = GINVirtual(num_tasks=d_out, **config.model_kwargs) + elif config.model == 'code-gpt-py': + from models.code_gpt import GPT2LMHeadLogit, GPT2FeaturizerLMHeadLogit + from transformers import GPT2Tokenizer name = 'microsoft/CodeGPT-small-py' tokenizer = GPT2Tokenizer.from_pretrained(name) if is_featurizer: @@ -69,14 +71,27 @@ def initialize_model(config, d_out, is_featurizer=False): else: model = GPT2LMHeadLogit.from_pretrained(name) model.resize_token_embeddings(len(tokenizer)) + elif config.model == 'logistic_regression': assert not is_featurizer, "Featurizer not supported for logistic regression" model = nn.Linear(out_features=d_out, **config.model_kwargs) + + elif config.model == 'detr': + if is_featurizer: # TODO + raise NotImplementedError('Featurizer not implemented for detection yet') + else: + model = initialize_detr_model(config, d_out) + else: raise ValueError(f'Model: {config.model} not recognized.') + return model + def initialize_bert_based_model(config, d_out, is_featurizer=False): + from models.bert.bert import BertClassifier, BertFeaturizer + from models.bert.distilbert import DistilBertClassifier, DistilBertFeaturizer + if config.model == 'bert-base-uncased': if is_featurizer: model = BertFeaturizer.from_pretrained(config.model, **config.model_kwargs) @@ -98,6 +113,8 @@ def initialize_bert_based_model(config, d_out, is_featurizer=False): return model def initialize_torchvision_model(name, d_out, **kwargs): + import torchvision + # get constructor and last layer names if name == 'wideresnet50': constructor_name = 'wide_resnet50_2' @@ -123,3 +140,44 @@ def initialize_torchvision_model(name, d_out, **kwargs): model.d_out = d_out setattr(model, last_layer_name, last_layer) return model + +def initialize_detr_model(config, d_out): + from models.detr.backbone import Backbone, Joiner + from models.detr.position_encoding import PositionEmbeddingSine + from models.detr.transformer import Transformer + from models.detr.detr import DETR + + position_embedding = PositionEmbeddingSine( + config.model_kwargs['hidden_dim'] // 2, + normalize=True) + + backbone = Backbone( + name=config.model_kwargs['backbone'], + train_backbone=config.model_kwargs['train_backbone'], + return_interm_layers=False, # No segmentation + dilation=config.model_kwargs['dilation']) + num_channels = backbone.num_channels + backbone = Joiner(backbone, position_embedding) + backbone.num_channels = num_channels + + transformer = Transformer( + d_model=config.model_kwargs['hidden_dim'], + dropout=config.model_kwargs['dropout'], + nhead=config.model_kwargs['nheads'], + dim_feedforward=config.model_kwargs['dim_feedforward'], + num_encoder_layers=config.model_kwargs['enc_layers'], + num_decoder_layers=config.model_kwargs['dec_layers'], + normalize_before=config.model_kwargs['pre_norm'], + return_intermediate_dec=True, + ) + + # Update these with dataset configs + model = DETR( + backbone, + transformer, + num_classes=d_out, + num_queries=config.model_kwargs['n_queries'], + aux_loss=config.model_kwargs['aux_loss'], + ) + + return model diff --git a/examples/run_expt.py b/examples/run_expt.py index 173603ab..fe0a8da1 100644 --- a/examples/run_expt.py +++ b/examples/run_expt.py @@ -61,6 +61,8 @@ def main(): # Objective parser.add_argument('--loss_function', choices = supported.losses) + parser.add_argument('--loss_kwargs', nargs='*', action=ParseKwargs, default={}, + help='keyword arguments for loss initialization passed as key1=value1 key2=value2') # Algorithm parser.add_argument('--groupby_fields', nargs='+') diff --git a/examples/utils.py b/examples/utils.py index 89780d62..dbda67d8 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -203,3 +203,14 @@ def get_model_prefix(dataset, config): config.log_dir, f"{dataset_name}_{replicate_str}_") return prefix + +# Adapted from https://discuss.pytorch.org/t/pytorch-tensor-to-device-for-a-list-of-dict/66283 +def move_to(obj, device): + if torch.is_tensor(obj): + return obj.to(device) + elif isinstance(obj, dict): + return {k: move_to(v, device) for k, v in obj.items()} + elif isinstance(obj, list): + return [move_to(v, device) for v in obj] + else: + raise TypeError("Invalid type for move_to") diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py index 0f5d7eb1..02a7cb35 100644 --- a/wilds/common/metrics/all_metrics.py +++ b/wilds/common/metrics/all_metrics.py @@ -19,7 +19,7 @@ def binary_logits_to_score(logits): def multiclass_logits_to_pred(logits): """ - Takes multi-class logits of size (batch_size, ..., n_classes) and returns predictions + Takes multi-class logits of size (batch_size, ..., n_classes) and returns predictions by taking an argmax at the last dimension """ assert logits.dim() > 1 @@ -142,3 +142,16 @@ def _compute(self, y_pred, y_true): def worst(self, metrics): return minimum(metrics) + +class DummyMetric(Metric): + def __init__(self, prediction_fn=None, name=None): + self.prediction_fn = prediction_fn + if name is None: + name = 'dummy' + super().__init__(name=name) + + def _compute(self, y_pred, y_true): + return -1 + + def worst(self, metrics): + return minimum(metrics) diff --git a/wilds/common/metrics/metric.py b/wilds/common/metrics/metric.py index 9c4372b0..3f42a5af 100644 --- a/wilds/common/metrics/metric.py +++ b/wilds/common/metrics/metric.py @@ -1,5 +1,5 @@ import numpy as np -from wilds.common.utils import avg_over_groups, get_counts +from wilds.common.utils import avg_over_groups, get_counts, numel import torch class Metric: @@ -82,7 +82,7 @@ def compute(self, y_pred, y_true, return_dict=True): Output (return_dict=True): - results (dict): Dictionary of results, mapping metric.agg_metric_field to avg_metric """ - if y_true.numel()==0: + if numel(y_true) == 0: agg_metric = torch.tensor(0., device=y_true.device) else: agg_metric = self._compute(y_pred, y_true) diff --git a/wilds/common/utils.py b/wilds/common/utils.py index 7854393a..fc7f438b 100644 --- a/wilds/common/utils.py +++ b/wilds/common/utils.py @@ -82,7 +82,6 @@ def avg_over_groups(v, g, n_groups): group_counts (Tensor) """ assert v.device==g.device - device = v.device assert v.numel()==g.numel() group_count = get_counts(g, n_groups) group_avgs = torch_scatter.scatter(src=v, index=g, dim_size=n_groups, reduce='mean') @@ -126,3 +125,11 @@ def threshold_at_recall(y_pred, y_true, global_recall=60): """ Calculate the model threshold to use to achieve a desired global_recall level. Assumes that y_true is a vector of the true binary labels.""" return np.percentile(y_pred[y_true == 1], 100-global_recall) + +def numel(obj): + if torch.is_tensor(obj): + return obj.numel() + elif isinstance(obj, list): + return len(obj) + else: + raise TypeError("Invalid type for numel") diff --git a/wilds/datasets/gwhd_dataset.py b/wilds/datasets/gwhd_dataset.py index 064be8ff..957ee374 100644 --- a/wilds/datasets/gwhd_dataset.py +++ b/wilds/datasets/gwhd_dataset.py @@ -3,10 +3,20 @@ import torch from pathlib import Path from PIL import Image -#from wilds.common.metrics.all_metrics import MultiTaskAccuracy from wilds.datasets.wilds_dataset import WILDSDataset +from wilds.common.grouper import CombinatorialGrouper +from wilds.common.metrics.all_metrics import DummyMetric +def _collate_fn(batch): + """ + Stack x (batch[0]) and metadata (batch[2]), but not y. + """ + batch = list(zip(*batch)) + batch[0] = torch.stack(batch[0]) + batch[2] = torch.stack(batch[2]) + return tuple(batch) + class GWHDDataset(WILDSDataset): """ The GWHD-wilds wheat head localization dataset. @@ -23,7 +33,7 @@ class GWHDDataset(WILDSDataset): Input (x): 1024x1024 RGB images of wheat field canopy between flowering and ripening. Output (y): - y is a nx4-dimensional vector where each line represents a box coordinate (top-x,top-y,height,width) + y is a nx4-dimensional vector where each line represents a box coordinate (x_min,y_min,x_max,y_max) Metadata: Each image is annotated with the ID of the domain it came from (integer from 0 to 10). Website: @@ -58,6 +68,9 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' self._data_dir = self.initialize_data_dir(root_dir, download) self._original_resolution = (1024, 1024) self.root = Path(self.data_dir) + self._is_detection = True + self._y_size = 1 + self._n_classes = 1 self._split_scheme = split_scheme @@ -91,29 +104,44 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' labels = list(df['labels'].values) self._split_array.extend([i] * len(labels)) - - - labels = [{"boxes": torch.stack([ torch.tensor([int(i) for i in box.split(" ")]) for box in boxes.split(";")]) ,"labels": torch.tensor([1.]*len(list(boxes.split(";")))).long() } if type(boxes) != float else {"boxes":torch.empty(0,4),"labels":torch.empty(0,1,dtype=torch.long)} for boxes in labels] + labels = [{ + "boxes": torch.stack([ + torch.tensor([int(i) for i in box.split(" ")]) + for box in boxes.split(";") + ]), + "labels": torch.tensor([1.]*len(list(boxes.split(";")))).long() + } if type(boxes) != float else { + "boxes": torch.empty(0,4), + "labels": torch.empty(0,1,dtype=torch.long) + } for boxes in labels] + # TODO: Figure out empty images + + # The above boxes are (x_min,y_min,x_max,y_max) + # Convert labels into (center_x, center_y, w, h) normalized, which is what DETR expects + # TODO: If it's not standard, we can probably put this in a transform somewhere + for label in labels: + boxes = label['boxes'] + center_x = (boxes[:, 0] + boxes[:, 2]) / 2 / self._original_resolution[0] + center_y = (boxes[:, 1] + boxes[:, 3]) / 2 / self._original_resolution[1] + width = (boxes[:, 2] - boxes[:, 0]) / self._original_resolution[0] + height = (boxes[:, 3] - boxes[:, 1]) / self._original_resolution[1] + label['boxes'] = torch.stack((center_x, center_y, width, height), dim=1) self._y_array.extend(labels) - - self._metadata_array.extend(list(df['group'].values)) - - self._y_size = 1 - - self._metadata_fields = ["domain"] - self._split_array = np.array(self._split_array) - - - + self._metadata_fields = ['location'] self._metadata_array = torch.tensor(self._metadata_array, dtype=torch.long).unsqueeze(1) - #self._metric = MultiTaskAccuracy() + self._eval_grouper = CombinatorialGrouper( + dataset=self, + groupby_fields=['location']) + + self._metric = DummyMetric() # TODO + self._collate = _collate_fn def get_input(self, idx): """ diff --git a/wilds/datasets/wilds_dataset.py b/wilds/datasets/wilds_dataset.py index 1f8bf21a..b650e556 100644 --- a/wilds/datasets/wilds_dataset.py +++ b/wilds/datasets/wilds_dataset.py @@ -106,6 +106,10 @@ def check_init(self): # Check metadata assert len(self.metadata_array.shape) == 2 assert len(self.metadata_fields) == self.metadata_array.shape[1] + + # Check that it is not both classification and detection + assert not (self.is_classification and self.is_detection) + # For convenience, include y in metadata_fields if y_size == 1 if self.y_size == 1: assert 'y' in self.metadata_fields @@ -242,10 +246,16 @@ def n_classes(self): def is_classification(self): """ Boolean. True if the task is classification, and false otherwise. - Used for logging purposes. """ return (self.n_classes is not None) + @property + def is_detection(self): + """ + Boolean. True if the task is detection, and false otherwise. + """ + return getattr(self, '_is_detection', False) + @property def metadata_fields(self): """ From 70546644f433ecf5ca437cec1b065681a882d8e0 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Sat, 13 Mar 2021 21:03:22 -0800 Subject: [PATCH 032/244] working version --- examples/algorithms/algorithm.py | 25 ++++++---- examples/algorithms/single_model_algorithm.py | 6 ++- examples/configs/datasets.py | 7 +-- examples/configs/supported.py | 2 + examples/models/detr/detr.py | 9 ++-- examples/models/initializer.py | 16 +++++++ examples/run_expt.py | 8 +++- examples/train.py | 29 ++++++++---- examples/utils.py | 47 +++++++++++++++++++ wilds/common/metrics/all_metrics.py | 10 +++- wilds/datasets/gwhd_dataset.py | 13 ++++- 11 files changed, 142 insertions(+), 30 deletions(-) diff --git a/examples/algorithms/algorithm.py b/examples/algorithms/algorithm.py index c93d960a..8eb0d47e 100644 --- a/examples/algorithms/algorithm.py +++ b/examples/algorithms/algorithm.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +from utils import move_to, detach_and_clone class Algorithm(nn.Module): def __init__(self, device): @@ -93,19 +94,23 @@ def sanitize_dict(self, in_dict, to_out_device=True): Helper function that sanitizes dictionaries by: - moving to the specified output device - removing any gradient information - - turning any Tensor of size 1 to a simple number + - detaching and cloning the tensors Args: - in_dict (dictionary) Output: - out_dict (dictionary): sanitized version of in_dict """ - out_dict = {} - for k, v in in_dict.items(): - if isinstance(v, torch.Tensor): - v_out = v.detach().clone() - if to_out_device: - v_out = v_out.to(self.out_device) - else: - v_out = v - out_dict[k] = v_out + out_dict = detach_and_clone(in_dict) + if to_out_device: + out_dict = move_to(out_dict, self.out_device) + # + # out_dict = {} + # for k, v in in_dict.items(): + # if isinstance(v, torch.Tensor): + # v_out = v.detach().clone() + # if to_out_device: + # v_out = v_out.to(self.out_device) + # else: + # v_out = v + # out_dict[k] = v_out return out_dict diff --git a/examples/algorithms/single_model_algorithm.py b/examples/algorithms/single_model_algorithm.py index 1ee1ad5d..bd6845ad 100644 --- a/examples/algorithms/single_model_algorithm.py +++ b/examples/algorithms/single_model_algorithm.py @@ -81,7 +81,11 @@ def evaluate(self, batch): assert not self.is_training results = self.process_batch(batch) results['objective'] = self.objective(results).item() - self.update_log(results) + try: + self.update_log(results) + except: + import IPython + IPython.embed() return self.sanitize_dict(results) def update(self, batch): diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index f93c38d1..462a1106 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -289,8 +289,9 @@ 'eval_transform': 'image_base', 'model_kwargs': { 'aux_loss': True, - 'n_queries': 150, - 'n_classes': 1}, + 'n_queries': 200, + 'n_classes': 1, + 'pretrained': True}, 'loss_function': 'detr_set_criterion', 'groupby_fields': ['location'], 'val_metric': 'dummy_all', # TODO @@ -303,7 +304,7 @@ 'lr': 1e-4, 'weight_decay': 1e-4, 'n_epochs': 10, - 'process_outputs_function': None, + 'process_outputs_function': 'remove_detr_aux_outputs', } } diff --git a/examples/configs/supported.py b/examples/configs/supported.py index a1d30fdb..32fac62f 100644 --- a/examples/configs/supported.py +++ b/examples/configs/supported.py @@ -1,5 +1,6 @@ # metrics from wilds.common.metrics.all_metrics import Accuracy, MultiTaskAccuracy, MSE, multiclass_logits_to_pred, binary_logits_to_pred +from utils import remove_key algo_log_metrics = { 'accuracy': Accuracy(prediction_fn=multiclass_logits_to_pred), @@ -12,6 +13,7 @@ process_outputs_functions = { 'binary_logits_to_pred': binary_logits_to_pred, 'multiclass_logits_to_pred': multiclass_logits_to_pred, + 'remove_detr_aux_outputs': remove_key('aux_outputs'), None: None, } diff --git a/examples/models/detr/detr.py b/examples/models/detr/detr.py index ef9e4e6a..1bf5397b 100644 --- a/examples/models/detr/detr.py +++ b/examples/models/detr/detr.py @@ -149,7 +149,6 @@ def loss_boxes(self, outputs, targets, indices, num_boxes): src_boxes = outputs['pred_boxes'][idx] target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) - loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none').sum(dim=1) loss_giou = 1 - torch.diag(box_ops.generalized_box_iou( @@ -165,8 +164,12 @@ def loss_boxes(self, outputs, targets, indices, num_boxes): pos = 0 for i, tgt_length in enumerate(tgt_lengths): - losses['loss_bbox'][i] = loss_bbox[pos:pos+tgt_length].mean() - losses['loss_giou'][i] = loss_giou[pos:pos+tgt_length].mean() + if tgt_length == 0: + losses['loss_bbox'][i] = 0 + losses['loss_giou'][i] = 0 + else: + losses['loss_bbox'][i] = loss_bbox[pos:pos+tgt_length].mean() + losses['loss_giou'][i] = loss_giou[pos:pos+tgt_length].mean() pos += tgt_length # losses['loss_bbox'] = loss_bbox.sum() / num_boxes diff --git a/examples/models/initializer.py b/examples/models/initializer.py index 86f0b83d..9ec47b62 100644 --- a/examples/models/initializer.py +++ b/examples/models/initializer.py @@ -1,3 +1,4 @@ +import torch import torch.nn as nn from models.layers import Identity @@ -142,6 +143,7 @@ def initialize_torchvision_model(name, d_out, **kwargs): return model def initialize_detr_model(config, d_out): + from models.detr.backbone import Backbone, Joiner from models.detr.position_encoding import PositionEmbeddingSine from models.detr.transformer import Transformer @@ -180,4 +182,18 @@ def initialize_detr_model(config, d_out): aux_loss=config.model_kwargs['aux_loss'], ) + if config.model_kwargs['pretrained']: + # Calling torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True, num_classes=d_out) does not work + # due to a ModuleNotFoundError. Perhaps some configuration error there. + # So we have to do it manually. + + checkpoint = torch.hub.load_state_dict_from_url( + url='https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth', + map_location='cpu', + check_hash=True) + del checkpoint["model"]["query_embed.weight"] + del checkpoint["model"]["class_embed.weight"] + del checkpoint["model"]["class_embed.bias"] + model.load_state_dict(checkpoint["model"], strict=False) + return model diff --git a/examples/run_expt.py b/examples/run_expt.py index fe0a8da1..8b44a4da 100644 --- a/examples/run_expt.py +++ b/examples/run_expt.py @@ -19,8 +19,12 @@ from configs.utils import populate_defaults import configs.supported as supported +import torch.multiprocessing + def main(): - ''' set default hyperparams in default_hyperparams.py ''' + torch.multiprocessing.set_sharing_strategy('file_system') + + ''' to set default hyperparams for each dataset/model, look at configs/ ''' parser = argparse.ArgumentParser() # Required arguments @@ -62,7 +66,7 @@ def main(): # Objective parser.add_argument('--loss_function', choices = supported.losses) parser.add_argument('--loss_kwargs', nargs='*', action=ParseKwargs, default={}, - help='keyword arguments for loss initialization passed as key1=value1 key2=value2') + help='keyword arguments for loss initialization passed as key1=value1 key2=value2') # Algorithm parser.add_argument('--groupby_fields', nargs='+') diff --git a/examples/train.py b/examples/train.py index 774f3d1e..fa27e2e3 100644 --- a/examples/train.py +++ b/examples/train.py @@ -1,7 +1,7 @@ import os from tqdm import tqdm import torch -from utils import save_model, save_pred, get_pred_prefix, get_model_prefix +from utils import save_model, save_pred, get_pred_prefix, get_model_prefix, detach_and_clone, collate_list import torch.autograd.profiler as profiler from configs.supported import process_outputs_functions @@ -26,6 +26,9 @@ def run_epoch(algorithm, dataset, general_logger, epoch, config, train): batch_idx = 0 iterator = tqdm(dataset['loader']) if config.progress_bar else dataset['loader'] + # import psutil + # process = psutil.Process(os.getpid()) + for batch in iterator: if train: batch_results = algorithm.update(batch) @@ -34,23 +37,33 @@ def run_epoch(algorithm, dataset, general_logger, epoch, config, train): # These tensors are already detached, but we need to clone them again # Otherwise they don't get garbage collected properly in some versions - # The subsequent detach is just for safety + # The extra detach is just for safety # (they should already be detached in batch_results) - epoch_y_true.append(batch_results['y_true'].clone().detach()) - y_pred = batch_results['y_pred'].clone().detach() + epoch_y_true.append(detach_and_clone(batch_results['y_true'])) + y_pred = detach_and_clone(batch_results['y_pred']) if config.process_outputs_function is not None: y_pred = process_outputs_functions[config.process_outputs_function](y_pred) epoch_y_pred.append(y_pred) - epoch_metadata.append(batch_results['metadata'].clone().detach()) + epoch_metadata.append(detach_and_clone(batch_results['metadata'])) if train and (batch_idx+1) % config.log_every==0: log_results(algorithm, dataset, general_logger, epoch, batch_idx) + # t = torch.cuda.get_device_properties(0).total_memory + # r = torch.cuda.memory_reserved(0) + # a = torch.cuda.memory_allocated(0) + # f = r-a # free inside reserved + # print(f'Total: {f:10} Reserved: {r:10} Allocated: {a:10} Free: {f:10}') + # + # mem = process.memory_info().rss + # print(f'Mem: {mem / 1024 / 1024:6.1f}M') batch_idx += 1 - epoch_y_pred = torch.cat(epoch_y_pred) - epoch_y_true = torch.cat(epoch_y_true) - epoch_metadata = torch.cat(epoch_metadata) + + epoch_y_pred = collate_list(epoch_y_pred) + epoch_y_true = collate_list(epoch_y_true) + epoch_metadata = collate_list(epoch_metadata) + results, results_str = dataset['dataset'].eval( epoch_y_pred, epoch_y_true, diff --git a/examples/utils.py b/examples/utils.py index dbda67d8..5a18d030 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -1,6 +1,7 @@ import sys import os import csv +import copy import argparse import random from pathlib import Path @@ -212,5 +213,51 @@ def move_to(obj, device): return {k: move_to(v, device) for k, v in obj.items()} elif isinstance(obj, list): return [move_to(v, device) for v in obj] + elif isinstance(obj, float) or isinstance(obj, int): + return obj else: raise TypeError("Invalid type for move_to") + +def detach_and_clone(obj): + if torch.is_tensor(obj): + return obj.detach().clone() + elif isinstance(obj, dict): + obj = copy.copy(obj) + return {k: detach_and_clone(v) for k, v in obj.items()} + elif isinstance(obj, list): + obj = copy.copy(obj) + return [detach_and_clone(v) for v in obj] + elif isinstance(obj, float) or isinstance(obj, int): + return obj + else: + raise TypeError("Invalid type for detach_and_clone") + +def collate_list(vec): + """ + If vec is a list of Tensors, it concatenates them all along the first dimension. + + If vec is a list of lists, it joins these lists together, but does not attempt to recursively collate. This allows each element of the list to be, e.g., its own dict. + + If vec is a list of dicts (with the same keys in each dict), it returns a single dict with the same keys. For each key, it recursively collates all entries in the list. + """ + if not isinstance(vec, list): + raise TypeError("collate_list must take in a list") + elem = vec[0] + if torch.is_tensor(elem): + return torch.cat(vec) + elif isinstance(elem, list): + return [obj for sublist in vec for obj in sublist] + elif isinstance(elem, dict): + return {k: collate_list([d[k] for d in vec]) for k in elem} + else: + raise TypeError("Elements of the list to collate must be tensors or dicts.") + +def remove_key(key): + """ + Returns a function that strips out a key from a dict. + """ + def remove(d): + if not isinstance(d, dict): + raise TypeError("remove_key must take in a dict") + return {k: v for (k,v) in d.items() if k != key} + return remove diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py index 02a7cb35..d3c78592 100644 --- a/wilds/common/metrics/all_metrics.py +++ b/wilds/common/metrics/all_metrics.py @@ -4,7 +4,7 @@ import torch.nn.functional as F from wilds.common.metrics.metric import Metric, ElementwiseMetric, MultiTaskMetric from wilds.common.metrics.loss import ElementwiseLoss -from wilds.common.utils import avg_over_groups, minimum, maximum +from wilds.common.utils import avg_over_groups, minimum, maximum, get_counts import sklearn.metrics from scipy.stats import pearsonr @@ -151,7 +151,13 @@ def __init__(self, prediction_fn=None, name=None): super().__init__(name=name) def _compute(self, y_pred, y_true): - return -1 + return torch.tensor(-1) + + def _compute_group_wise(self, y_pred, y_true, g, n_groups): + group_metrics = torch.ones(n_groups, device=g.device) * -1 + group_counts = get_counts(g, n_groups) + worst_group_metric = self.worst(group_metrics) + return group_metrics, group_counts, worst_group_metric def worst(self, metrics): return minimum(metrics) diff --git a/wilds/datasets/gwhd_dataset.py b/wilds/datasets/gwhd_dataset.py index 957ee374..ad1a7979 100644 --- a/wilds/datasets/gwhd_dataset.py +++ b/wilds/datasets/gwhd_dataset.py @@ -112,7 +112,8 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' "labels": torch.tensor([1.]*len(list(boxes.split(";")))).long() } if type(boxes) != float else { "boxes": torch.empty(0,4), - "labels": torch.empty(0,1,dtype=torch.long) + # "labels": torch.empty(0,1,dtype=torch.long) + "labels": torch.empty(0,dtype=torch.long) } for boxes in labels] # TODO: Figure out empty images @@ -127,6 +128,9 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' height = (boxes[:, 3] - boxes[:, 1]) / self._original_resolution[1] label['boxes'] = torch.stack((center_x, center_y, width, height), dim=1) + # num_boxes = [len(example['boxes']) for example in labels] + # print(f'Max num_boxes is {max(num_boxes)}') + self._y_array.extend(labels) self._metadata_array.extend(list(df['group'].values)) @@ -149,6 +153,13 @@ def get_input(self, idx): """ img_filename = self.root / "images" / self._image_array[idx] x = Image.open(img_filename) + # + # import psutil + # for proc in psutil.process_iter(): + # try: + # print(proc.open_files()) + # except (psutil.AccessDenied): + # pass return x def eval(self, y_pred, y_true, metadata): From 8bee0647389fc1ec54683b3dc29cb67b590e116b Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Sat, 13 Mar 2021 23:10:43 -0800 Subject: [PATCH 033/244] letter change --- examples/run_expt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/run_expt.py b/examples/run_expt.py index 8b44a4da..7df658f2 100644 --- a/examples/run_expt.py +++ b/examples/run_expt.py @@ -24,7 +24,7 @@ def main(): torch.multiprocessing.set_sharing_strategy('file_system') - ''' to set default hyperparams for each dataset/model, look at configs/ ''' + ''' to see default hyperparams for each dataset/model, look at configs/ ''' parser = argparse.ArgumentParser() # Required arguments From 2e7a0e1aa9d9327ad022bc4b00e2d04b47b8e5aa Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Sun, 14 Mar 2021 11:39:22 -0700 Subject: [PATCH 034/244] save preds --- examples/configs/datasets.py | 4 ++++ examples/models/detr/detr.py | 1 - examples/run_expt.py | 5 ++++- examples/train.py | 29 +++++++++++++++-------------- examples/utils.py | 16 ++++++++++------ 5 files changed, 33 insertions(+), 22 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 462a1106..99b251ce 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -305,6 +305,10 @@ 'weight_decay': 1e-4, 'n_epochs': 10, 'process_outputs_function': 'remove_detr_aux_outputs', + 'loader_kwargs': { + 'num_workers': 1, + 'pin_memory': True, + }, } } diff --git a/examples/models/detr/detr.py b/examples/models/detr/detr.py index 1bf5397b..c7135bb9 100644 --- a/examples/models/detr/detr.py +++ b/examples/models/detr/detr.py @@ -236,7 +236,6 @@ def forward(self, outputs, targets): losses.update(l_dict) # Sum up weighted losses by element - device = outputs['pred_logits'].device elementwise_loss = torch.zeros(len(outputs['pred_logits']), device=device) diff --git a/examples/run_expt.py b/examples/run_expt.py index 7df658f2..e8039113 100644 --- a/examples/run_expt.py +++ b/examples/run_expt.py @@ -275,12 +275,15 @@ def main(): epoch = best_epoch else: epoch = config.eval_epoch + if epoch == best_epoch: + is_best = True evaluate( algorithm=algorithm, datasets=datasets, epoch=epoch, general_logger=logger, - config=config) + config=config, + is_best=is_best) logger.close() for split in datasets: diff --git a/examples/train.py b/examples/train.py index fa27e2e3..22028ded 100644 --- a/examples/train.py +++ b/examples/train.py @@ -59,11 +59,9 @@ def run_epoch(algorithm, dataset, general_logger, epoch, config, train): batch_idx += 1 - epoch_y_pred = collate_list(epoch_y_pred) epoch_y_true = collate_list(epoch_y_true) epoch_metadata = collate_list(epoch_metadata) - results, results_str = dataset['dataset'].eval( epoch_y_pred, epoch_y_true, @@ -125,7 +123,7 @@ def train(algorithm, datasets, general_logger, config, epoch_offset, best_val_me general_logger.write('\n') -def evaluate(algorithm, datasets, epoch, general_logger, config): +def evaluate(algorithm, datasets, epoch, general_logger, config, is_best): algorithm.eval() for split, dataset in datasets.items(): if (not config.evaluate_all_splits) and (split not in config.eval_splits): @@ -136,17 +134,20 @@ def evaluate(algorithm, datasets, epoch, general_logger, config): iterator = tqdm(dataset['loader']) if config.progress_bar else dataset['loader'] for batch in iterator: batch_results = algorithm.evaluate(batch) - epoch_y_true.append(batch_results['y_true'].clone().detach()) - y_pred = batch_results['y_pred'].clone().detach() + epoch_y_true.append(detach_and_clone(batch_results['y_true'])) + y_pred = detach_and_clone(batch_results['y_pred']) if config.process_outputs_function is not None: y_pred = process_outputs_functions[config.process_outputs_function](y_pred) epoch_y_pred.append(y_pred) - epoch_metadata.append(batch_results['metadata'].clone().detach()) + epoch_metadata.append(detach_and_clone(batch_results['metadata'])) + epoch_y_pred = collate_list(epoch_y_pred) + epoch_y_true = collate_list(epoch_y_true) + epoch_metadata = collate_list(epoch_metadata) results, results_str = dataset['dataset'].eval( - torch.cat(epoch_y_pred), - torch.cat(epoch_y_true), - torch.cat(epoch_metadata)) + epoch_y_pred, + epoch_y_true, + epoch_metadata) results['epoch'] = epoch dataset['eval_logger'].log(results) @@ -155,7 +156,7 @@ def evaluate(algorithm, datasets, epoch, general_logger, config): # Skip saving train preds, since the train loader generally shuffles the data if split != 'train': - save_pred_if_needed(y_pred, dataset, epoch, config, is_best=False, force_save=True) + save_pred_if_needed(epoch_y_pred, dataset, epoch, config, is_best, force_save=True) def log_results(algorithm, dataset, general_logger, epoch, batch_idx): @@ -173,11 +174,11 @@ def save_pred_if_needed(y_pred, dataset, epoch, config, is_best, force_save=Fals if config.save_pred: prefix = get_pred_prefix(dataset, config) if force_save or (config.save_step is not None and (epoch + 1) % config.save_step == 0): - save_pred(y_pred, prefix + f'epoch:{epoch}_pred.csv') - if config.save_last: - save_pred(y_pred, prefix + f'epoch:last_pred.csv') + save_pred(y_pred, prefix + f'epoch:{epoch}_pred') + if (not force_save) and config.save_last: + save_pred(y_pred, prefix + f'epoch:last_pred') if config.save_best and is_best: - save_pred(y_pred, prefix + f'epoch:best_pred.csv') + save_pred(y_pred, prefix + f'epoch:best_pred') def save_model_if_needed(algorithm, dataset, epoch, config, is_best, best_val_metric): diff --git a/examples/utils.py b/examples/utils.py index 5a18d030..b4b1aa69 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -1,7 +1,6 @@ import sys import os import csv -import copy import argparse import random from pathlib import Path @@ -177,9 +176,16 @@ def initialize_wandb(config): project=f"wilds") wandb.config.update(config) -def save_pred(y_pred, csv_path): - df = pd.DataFrame(y_pred.numpy()) - df.to_csv(csv_path, index=False, header=False) +def save_pred(y_pred, path_prefix): + # Single tensor + if torch.is_tensor(y_pred): + df = pd.DataFrame(y_pred.numpy()) + df.to_csv(path_prefix + '.csv', index=False, header=False) + # Dictionary + elif isinstance(y_pred, dict): + torch.save(y_pred, path_prefix + '.pth') + else: + raise TypeError("Invalid type for save_pred") def get_replicate_str(dataset, config): if dataset['dataset'].dataset_name == 'poverty': @@ -222,10 +228,8 @@ def detach_and_clone(obj): if torch.is_tensor(obj): return obj.detach().clone() elif isinstance(obj, dict): - obj = copy.copy(obj) return {k: detach_and_clone(v) for k, v in obj.items()} elif isinstance(obj, list): - obj = copy.copy(obj) return [detach_and_clone(v) for v in obj] elif isinstance(obj, float) or isinstance(obj, int): return obj From 79972ebe084f9708420f6991800f9572afa346c9 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Mon, 15 Mar 2021 11:16:02 -0700 Subject: [PATCH 035/244] comment cleanups --- examples/algorithms/algorithm.py | 12 +----------- examples/algorithms/single_model_algorithm.py | 8 ++------ examples/losses.py | 2 ++ examples/models/initializer.py | 2 -- 4 files changed, 5 insertions(+), 19 deletions(-) diff --git a/examples/algorithms/algorithm.py b/examples/algorithms/algorithm.py index 8eb0d47e..5c734766 100644 --- a/examples/algorithms/algorithm.py +++ b/examples/algorithms/algorithm.py @@ -94,7 +94,7 @@ def sanitize_dict(self, in_dict, to_out_device=True): Helper function that sanitizes dictionaries by: - moving to the specified output device - removing any gradient information - - detaching and cloning the tensors + - detaching and cloning the tensors Args: - in_dict (dictionary) Output: @@ -103,14 +103,4 @@ def sanitize_dict(self, in_dict, to_out_device=True): out_dict = detach_and_clone(in_dict) if to_out_device: out_dict = move_to(out_dict, self.out_device) - # - # out_dict = {} - # for k, v in in_dict.items(): - # if isinstance(v, torch.Tensor): - # v_out = v.detach().clone() - # if to_out_device: - # v_out = v_out.to(self.out_device) - # else: - # v_out = v - # out_dict[k] = v_out return out_dict diff --git a/examples/algorithms/single_model_algorithm.py b/examples/algorithms/single_model_algorithm.py index bd6845ad..c5d1071b 100644 --- a/examples/algorithms/single_model_algorithm.py +++ b/examples/algorithms/single_model_algorithm.py @@ -80,12 +80,8 @@ def evaluate(self, batch): """ assert not self.is_training results = self.process_batch(batch) - results['objective'] = self.objective(results).item() - try: - self.update_log(results) - except: - import IPython - IPython.embed() + results['objective'] = self.objective(results).item() + self.update_log(results) return self.sanitize_dict(results) def update(self, batch): diff --git a/examples/losses.py b/examples/losses.py index 9d8df5b7..ff48e6da 100644 --- a/examples/losses.py +++ b/examples/losses.py @@ -17,6 +17,8 @@ def initialize_loss(config, d_out): elif config.loss_function == 'detr_set_criterion': return ElementwiseLoss(loss_fn=get_detr_set_criterion(config, d_out)) + else: + raise ValueError(f'config.loss_function {config.loss_function} not recognized') def get_detr_set_criterion(config, d_out): from examples.models.detr.matcher import HungarianMatcher diff --git a/examples/models/initializer.py b/examples/models/initializer.py index 9ec47b62..0c6387a8 100644 --- a/examples/models/initializer.py +++ b/examples/models/initializer.py @@ -173,7 +173,6 @@ def initialize_detr_model(config, d_out): return_intermediate_dec=True, ) - # Update these with dataset configs model = DETR( backbone, transformer, @@ -186,7 +185,6 @@ def initialize_detr_model(config, d_out): # Calling torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True, num_classes=d_out) does not work # due to a ModuleNotFoundError. Perhaps some configuration error there. # So we have to do it manually. - checkpoint = torch.hub.load_state_dict_from_url( url='https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth', map_location='cpu', From 5802cf23c9e33839ae462bfa9d550d2bdcb93d14 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Tue, 16 Mar 2021 13:26:47 -0700 Subject: [PATCH 036/244] label bug fix and added check_init --- examples/losses.py | 2 +- examples/models/detr/__init__.py | 6 +- examples/models/detr/detr.py | 157 ++++++++++++++++--------------- examples/utils.py | 2 +- wilds/common/metrics/metric.py | 5 +- wilds/common/utils.py | 16 +++- wilds/datasets/gwhd_dataset.py | 18 ++-- wilds/datasets/wilds_dataset.py | 6 +- 8 files changed, 114 insertions(+), 98 deletions(-) diff --git a/examples/losses.py b/examples/losses.py index ff48e6da..3e5cd933 100644 --- a/examples/losses.py +++ b/examples/losses.py @@ -15,7 +15,7 @@ def initialize_loss(config, d_out): return MultiTaskLoss(loss_fn=nn.BCEWithLogitsLoss(reduction='none')) elif config.loss_function == 'detr_set_criterion': - return ElementwiseLoss(loss_fn=get_detr_set_criterion(config, d_out)) + return ElementwiseLoss(loss_fn=get_detr_set_criterion(config, d_out)) else: raise ValueError(f'config.loss_function {config.loss_function} not recognized') diff --git a/examples/models/detr/__init__.py b/examples/models/detr/__init__.py index a3f26531..435bab61 100644 --- a/examples/models/detr/__init__.py +++ b/examples/models/detr/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -from .detr import build +# from .detr import build -def build_model(args): - return build(args) +# def build_model(args): + # return build(args) diff --git a/examples/models/detr/detr.py b/examples/models/detr/detr.py index c7135bb9..b734b4a7 100644 --- a/examples/models/detr/detr.py +++ b/examples/models/detr/detr.py @@ -245,35 +245,6 @@ def forward(self, outputs, targets): return elementwise_loss -class PostProcess(nn.Module): - """ This module converts the model's output into the format expected by the coco api""" - @torch.no_grad() - def forward(self, outputs, target_sizes): - """ Perform the computation - Parameters: - outputs: raw outputs of the model - target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch - For evaluation, this must be the original image size (before any data augmentation) - For visualization, this should be the image size after data augment, but before padding - """ - out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes'] - - assert len(out_logits) == len(target_sizes) - assert target_sizes.shape[1] == 2 - - prob = F.softmax(out_logits, -1) - scores, labels = prob[..., :-1].max(-1) - - # convert to [x0, y0, x1, y1] format - boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) - # and from relative [0, 1] to absolute [0, height] coordinates - img_h, img_w = target_sizes.unbind(1) - scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) - boxes = boxes * scale_fct[:, None, :] - - results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)] - - return results class MLP(nn.Module): @@ -291,50 +262,84 @@ def forward(self, x): return x -def build(args): - # the `num_classes` naming here is somewhat misleading. - # it indeed corresponds to `max_obj_id + 1`, where max_obj_id - # is the maximum id for a class in your dataset. For example, - # COCO has a max_obj_id of 90, so we pass `num_classes` to be 91. - # As another example, for a dataset that has a single class with id 1, - # you should pass `num_classes` to be 2 (max_obj_id + 1). - # For more details on this, check the following discussion - # https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223 - num_classes = 20 if args.dataset_file != 'coco' else 91 - if args.dataset_file == "coco_panoptic": - # for panoptic, we just add a num_classes that is large enough to hold - # max_obj_id + 1, but the exact value doesn't really matter - num_classes = 250 - device = torch.device(args.device) - - backbone = build_backbone(args) - - transformer = build_transformer(args) - - model = DETR( - backbone, - transformer, - num_classes=num_classes, - num_queries=args.num_queries, - aux_loss=args.aux_loss, - ) - - matcher = build_matcher(args) - weight_dict = {'loss_ce': 1, 'loss_bbox': args.bbox_loss_coef} - weight_dict['loss_giou'] = args.giou_loss_coef - - # TODO this is a hack - if args.aux_loss: - aux_weight_dict = {} - for i in range(args.dec_layers - 1): - aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()}) - weight_dict.update(aux_weight_dict) - - losses = ['labels', 'boxes', 'cardinality'] - - criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict, - eos_coef=args.eos_coef, losses=losses) - criterion.to(device) - postprocessors = {'bbox': PostProcess()} - - return model, criterion, postprocessors +# +# class PostProcess(nn.Module): +# """ This module converts the model's output into the format expected by the coco api""" +# @torch.no_grad() +# def forward(self, outputs, target_sizes): +# """ Perform the computation +# Parameters: +# outputs: raw outputs of the model +# target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch +# For evaluation, this must be the original image size (before any data augmentation) +# For visualization, this should be the image size after data augment, but before padding +# """ +# out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes'] +# +# assert len(out_logits) == len(target_sizes) +# assert target_sizes.shape[1] == 2 +# +# prob = F.softmax(out_logits, -1) +# scores, labels = prob[..., :-1].max(-1) +# +# # convert to [x0, y0, x1, y1] format +# boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) +# # and from relative [0, 1] to absolute [0, height] coordinates +# img_h, img_w = target_sizes.unbind(1) +# scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) +# boxes = boxes * scale_fct[:, None, :] +# +# results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)] +# +# return results +# +# +# +# +# def build(args): +# # the `num_classes` naming here is somewhat misleading. +# # it indeed corresponds to `max_obj_id + 1`, where max_obj_id +# # is the maximum id for a class in your dataset. For example, +# # COCO has a max_obj_id of 90, so we pass `num_classes` to be 91. +# # As another example, for a dataset that has a single class with id 1, +# # you should pass `num_classes` to be 2 (max_obj_id + 1). +# # For more details on this, check the following discussion +# # https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223 +# num_classes = 20 if args.dataset_file != 'coco' else 91 +# if args.dataset_file == "coco_panoptic": +# # for panoptic, we just add a num_classes that is large enough to hold +# # max_obj_id + 1, but the exact value doesn't really matter +# num_classes = 250 +# device = torch.device(args.device) +# +# backbone = build_backbone(args) +# +# transformer = build_transformer(args) +# +# model = DETR( +# backbone, +# transformer, +# num_classes=num_classes, +# num_queries=args.num_queries, +# aux_loss=args.aux_loss, +# ) +# +# matcher = build_matcher(args) +# weight_dict = {'loss_ce': 1, 'loss_bbox': args.bbox_loss_coef} +# weight_dict['loss_giou'] = args.giou_loss_coef +# +# # TODO this is a hack +# if args.aux_loss: +# aux_weight_dict = {} +# for i in range(args.dec_layers - 1): +# aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()}) +# weight_dict.update(aux_weight_dict) +# +# losses = ['labels', 'boxes', 'cardinality'] +# +# criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict, +# eos_coef=args.eos_coef, losses=losses) +# criterion.to(device) +# postprocessors = {'bbox': PostProcess()} +# +# return model, criterion, postprocessors diff --git a/examples/utils.py b/examples/utils.py index b4b1aa69..1a20ab45 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -178,7 +178,7 @@ def initialize_wandb(config): def save_pred(y_pred, path_prefix): # Single tensor - if torch.is_tensor(y_pred): + if torch.is_tensor(y_pred): df = pd.DataFrame(y_pred.numpy()) df.to_csv(path_prefix + '.csv', index=False, header=False) # Dictionary diff --git a/wilds/common/metrics/metric.py b/wilds/common/metrics/metric.py index 3f42a5af..8f5ee416 100644 --- a/wilds/common/metrics/metric.py +++ b/wilds/common/metrics/metric.py @@ -131,8 +131,9 @@ def _compute_group_wise(self, y_pred, y_true, g, n_groups): else: group_metrics.append( self._compute( - y_pred[g == group_idx], - y_true[g == group_idx])) + get_subset_from_mask(y_pred, g == group_idx), + get_subset_from_mask(y_true, g == group_idx))) + group_metrics = torch.stack(group_metrics) worst_group_metric = self.worst(group_metrics[group_counts>0]) diff --git a/wilds/common/utils.py b/wilds/common/utils.py index fc7f438b..e3b1440d 100644 --- a/wilds/common/utils.py +++ b/wilds/common/utils.py @@ -112,7 +112,6 @@ def subsample_idxs(idxs, num=5000, take_rest=False, seed=None): idxs = idxs[:num] return idxs - def shuffle_arr(arr, seed=None): seed = (seed + 548207) if seed is not None else None rng = np.random.default_rng(seed) @@ -133,3 +132,18 @@ def numel(obj): return len(obj) else: raise TypeError("Invalid type for numel") + +# def get_subset_from_mask(seq, mask): +# """ +# Mask should be a binary vector with the same length as seq. +# """ +# if torch.is_tensor(seq) or isinstance(seq, list): +# if len(mask) != len(seq): +# print(len(mask)) +# print(len(seq)) +# raise ValueError('Mask must have same length as the input.') +# return seq[mask] +# elif isinstance(seq, dict): +# return {k: get_subset_from_mask(v, mask) for k, v in seq.items()} +# else: +# raise TypeError('Input must be a Tensor, list, or dict.') diff --git a/wilds/datasets/gwhd_dataset.py b/wilds/datasets/gwhd_dataset.py index ad1a7979..406271f6 100644 --- a/wilds/datasets/gwhd_dataset.py +++ b/wilds/datasets/gwhd_dataset.py @@ -69,7 +69,8 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' self._original_resolution = (1024, 1024) self.root = Path(self.data_dir) self._is_detection = True - self._y_size = 1 + self._is_classification = False + self._y_size = None self._n_classes = 1 self._split_scheme = split_scheme @@ -109,7 +110,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' torch.tensor([int(i) for i in box.split(" ")]) for box in boxes.split(";") ]), - "labels": torch.tensor([1.]*len(list(boxes.split(";")))).long() + "labels": torch.tensor([0]*len(list(boxes.split(";")))).long() } if type(boxes) != float else { "boxes": torch.empty(0,4), # "labels": torch.empty(0,1,dtype=torch.long) @@ -129,16 +130,16 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' label['boxes'] = torch.stack((center_x, center_y, width, height), dim=1) # num_boxes = [len(example['boxes']) for example in labels] - # print(f'Max num_boxes is {max(num_boxes)}') + # print(f'Max num_boxes is {max(num_boxes)}') self._y_array.extend(labels) self._metadata_array.extend(list(df['group'].values)) self._split_array = np.array(self._split_array) - self._metadata_fields = ['location'] self._metadata_array = torch.tensor(self._metadata_array, dtype=torch.long).unsqueeze(1) + self._metadata_fields = ['location'] self._eval_grouper = CombinatorialGrouper( dataset=self, @@ -147,19 +148,14 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' self._metric = DummyMetric() # TODO self._collate = _collate_fn + super().__init__(root_dir, download, split_scheme) + def get_input(self, idx): """ Returns x for a given idx. """ img_filename = self.root / "images" / self._image_array[idx] x = Image.open(img_filename) - # - # import psutil - # for proc in psutil.process_iter(): - # try: - # print(proc.open_files()) - # except (psutil.AccessDenied): - # pass return x def eval(self, y_pred, y_true, metadata): diff --git a/wilds/datasets/wilds_dataset.py b/wilds/datasets/wilds_dataset.py index b650e556..8bf0128d 100644 --- a/wilds/datasets/wilds_dataset.py +++ b/wilds/datasets/wilds_dataset.py @@ -95,8 +95,8 @@ def check_init(self): assert 'train' in self.split_dict assert 'val' in self.split_dict - # Check that required arrays are Tensors - assert isinstance(self.y_array, torch.Tensor), 'y_array must be a torch.Tensor' + # Check the form of the required arrays + assert (isinstance(self.y_array, torch.Tensor) or isinstance(self.y_array, list)) assert isinstance(self.metadata_array, torch.Tensor), 'metadata_array must be a torch.Tensor' # Check that dimensions match @@ -247,7 +247,7 @@ def is_classification(self): """ Boolean. True if the task is classification, and false otherwise. """ - return (self.n_classes is not None) + return getattr(self, '_is_classification', (self.n_classes is not None)) @property def is_detection(self): From 524974b4aa0205002e6f7db0818e43f5ddd3af1e Mon Sep 17 00:00:00 2001 From: Etienne David Date: Wed, 17 Mar 2021 20:52:39 +0100 Subject: [PATCH 037/244] implementation of detection accuracy --- .gitignore | 2 + examples/configs/datasets.py | 2 +- wilds/common/metrics/all_metrics.py | 77 +++++++++++++++++++++++++++++ wilds/datasets/gwhd_dataset.py | 4 +- 4 files changed, 82 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index ac33582a..e1076393 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,5 @@ __pycache__ build dist wilds.egg-info +data +logs \ No newline at end of file diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 99b251ce..a201e4bd 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -294,7 +294,7 @@ 'pretrained': True}, 'loss_function': 'detr_set_criterion', 'groupby_fields': ['location'], - 'val_metric': 'dummy_all', # TODO + 'val_metric': 'detection_accuracy_avg', # TODO 'val_metric_decreasing': False, 'algo_log_metric': None, # TODO 'optimizer': 'Adam', diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py index d3c78592..ddf2baf0 100644 --- a/wilds/common/metrics/all_metrics.py +++ b/wilds/common/metrics/all_metrics.py @@ -1,5 +1,8 @@ import torch import torch.nn as nn +from torchvision.ops.boxes import box_iou +from torchvision.models.detection._utils import Matcher +from torchvision.ops import nms, box_convert import numpy as np import torch.nn.functional as F from wilds.common.metrics.metric import Metric, ElementwiseMetric, MultiTaskMetric @@ -161,3 +164,77 @@ def _compute_group_wise(self, y_pred, y_true, g, n_groups): def worst(self, metrics): return minimum(metrics) + +class DetectionAccuracy(ElementwiseMetric): + """ + Given a specific Intersection over union threshold, + determine the accuracy achieved for a one-class detector + """ + def __init__(self, prediction_fn=None, iou_threshold=0.5,score_threshold=0.5, name=None): + self.prediction_fn = prediction_fn + self.iou_threshold = iou_threshold + self.score_threshold = score_threshold + if name is None: + name = "detection_accuracy" + super().__init__(name=name) + + def _compute_element_wise(self, y_pred ,y_true ): + + + + batch_results = [] + for src_boxes, target_boxes, target_logits in zip( y_true, y_pred['pred_boxes'], y_pred['pred_logits']): + + # Here should be prediction_fn ? + + target_scores = F.softmax(target_logits, dim=1)[..., 0] + pred_boxes = target_boxes[target_scores > self.score_threshold] + + det_accuracy = self._accuracy(src_boxes["boxes"],pred_boxes,iou_threshold=self.iou_threshold) + batch_results.append(det_accuracy) + + return torch.tensor(batch_results) + + + def _accuracy(self, src_boxes,pred_boxes , iou_threshold = 1.): + total_gt = len(src_boxes) + total_pred = len(pred_boxes) + + + if total_gt > 0 and total_pred > 0: + + # Define the matcher and distance matrix based on iou + matcher = Matcher(iou_threshold,iou_threshold,allow_low_quality_matches=False) + + src_boxes = box_convert(src_boxes , "cxcywh" ,"xyxy") + pred_boxes = box_convert(pred_boxes , "cxcywh" ,"xyxy") + + + match_quality_matrix = box_iou(src_boxes,pred_boxes) + + results = matcher(match_quality_matrix) + + true_positive = torch.count_nonzero(results.unique() != -1) + matched_elements = results[results > -1] + + #in Matcher, a pred element can be matched only twice + false_positive = torch.count_nonzero(results == -1) + ( len(matched_elements) - len(matched_elements.unique())) + false_negative = total_gt - true_positive + + + return true_positive / ( true_positive + false_positive + false_negative ) + + elif total_gt == 0: + if total_pred > 0: + return torch.tensor(0.) + else: + return torch.tensor(1.) + elif total_gt > 0 and total_pred == 0: + return torch.tensor(0.) + + + + def worst(self, metrics): + return minimum(metrics) + + diff --git a/wilds/datasets/gwhd_dataset.py b/wilds/datasets/gwhd_dataset.py index 406271f6..151f9a0c 100644 --- a/wilds/datasets/gwhd_dataset.py +++ b/wilds/datasets/gwhd_dataset.py @@ -5,7 +5,7 @@ from PIL import Image from wilds.datasets.wilds_dataset import WILDSDataset from wilds.common.grouper import CombinatorialGrouper -from wilds.common.metrics.all_metrics import DummyMetric +from wilds.common.metrics.all_metrics import DetectionAccuracy def _collate_fn(batch): @@ -145,7 +145,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' dataset=self, groupby_fields=['location']) - self._metric = DummyMetric() # TODO + self._metric = DetectionAccuracy() # TODO self._collate = _collate_fn super().__init__(root_dir, download, split_scheme) From 713c9793052eb2e8b9b82541e0b3941240e5119b Mon Sep 17 00:00:00 2001 From: aikanor Date: Thu, 18 Mar 2021 19:06:08 -0700 Subject: [PATCH 038/244] integration besides eval --- examples/sbox_run_expt.ipynb | 465 +++++++++------------------- wilds/common/metrics/all_metrics.py | 11 +- wilds/common/metrics/metric.py | 2 +- 3 files changed, 162 insertions(+), 316 deletions(-) diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb index 5040aeb0..86525331 100644 --- a/examples/sbox_run_expt.ipynb +++ b/examples/sbox_run_expt.ipynb @@ -11,14 +11,21 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "4860.4765625\n" + "396.69921875\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:The OGB package is out of date. Your version is 1.2.4, while the latest version is 1.3.0.\n" ] } ], @@ -69,7 +76,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "WARNING:root:The OGB package is out of date. Your version is 1.2.4, while the latest version is 1.2.6.\n" + "WARNING:root:The WILDS package is out of date. Your version is 1.0.0, while the latest version is 1.1.0.\n", + "WARNING:root:The OGB package is out of date. Your version is 1.2.4, while the latest version is 1.3.0.\n" ] } ], @@ -105,7 +113,7 @@ { "data": { "text/plain": [ - "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" + "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" ] }, "execution_count": 2, @@ -208,23 +216,38 @@ "outputs": [], "source": [ "argstr_camelyon = \"--dataset camelyon17 --algorithm ERM --root_dir data\"\n", - "# argstr_camelyon = \"--dataset civilcomments --algorithm ERM --root_dir data\"\n", "config_camelyon = parser.parse_args(argstr_camelyon.split())\n", "config_camelyon = populate_defaults(config_camelyon)\n", "\n", + "argstr_bdd100k = \"--dataset bdd100k --algorithm ERM --root_dir data\"\n", + "config_bdd100k = parser.parse_args(argstr_bdd100k.split())\n", + "config_bdd100k = populate_defaults(config_bdd100k)\n", + "\n", "argstr_encode = \"--dataset encode-tfbs --algorithm ERM --root_dir data\"\n", "config_encode = parser.parse_args(argstr_encode.split())\n", "config_encode = populate_defaults(config_encode)\n", "\n", "config = config_camelyon\n", - "config = config_encode\n" + "config = config_encode\n", + "config = config_bdd100k\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "Namespace(algo_log_metric=None, algorithm='ERM', batch_size=None, coral_penalty_weight=None, dataset='encode-tfbs', dataset_kwargs={}, device=0, distinct_groups=None, download=False, eval_epoch=None, eval_loader='standard', eval_only=False, eval_splits=[], eval_transform=None, evaluate_all_splits=True, frac=1.0, group_dro_step_size=None, groupby_fields=None, irm_lambda=None, irm_penalty_anneal_iters=None, loader_kwargs={'num_workers': 1, 'pin_memory': True}, log_dir='./logs', log_every=50, loss_function=None, lr=None, max_grad_norm=None, max_token_length=None, model=None, model_kwargs={'pretrained': False}, n_epochs=None, n_groups_per_batch=None, no_group_logging=None, optimizer=None, optimizer_kwargs={'momentum': 0.9}, progress_bar=False, resize_scale=None, resume=False, root_dir='data', save_best=True, save_last=True, save_step=None, scheduler=None, scheduler_kwargs={}, scheduler_metric_name=None, scheduler_metric_split='val', seed=0, split_scheme=None, target_resolution=None, train_loader=None, train_transform=None, uniform_over_groups=None, use_wandb=False, val_metric=None, val_metric_decreasing=None, weight_decay=None)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "argstr_camelyon = \"--dataset camelyon17 --algorithm ERM --root_dir data\"\n", "# argstr_camelyon = \"--dataset civilcomments --algorithm ERM --root_dir data\"\n", @@ -237,23 +260,34 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "Namespace(algo_log_metric='multitask_accuracy', algorithm='ERM', batch_size=32, coral_penalty_weight=None, dataset='bdd100k', dataset_kwargs={}, device=0, distinct_groups=None, download=False, eval_epoch=None, eval_loader='standard', eval_only=False, eval_splits=[], eval_transform='image_base', evaluate_all_splits=True, frac=1.0, group_dro_step_size=None, groupby_fields=None, irm_lambda=None, irm_penalty_anneal_iters=None, loader_kwargs={'num_workers': 1, 'pin_memory': True}, log_dir='./logs', log_every=50, loss_function='multitask_bce', lr=0.001, max_grad_norm=None, max_token_length=None, model='resnet50', model_kwargs={'pretrained': False}, n_epochs=10, n_groups_per_batch=4, no_group_logging=True, optimizer='SGD', optimizer_kwargs={'momentum': 0.9}, progress_bar=False, resize_scale=None, resume=False, root_dir='data', save_best=True, save_last=True, save_step=None, scheduler=None, scheduler_kwargs={}, scheduler_metric_name=None, scheduler_metric_split='val', seed=0, split_scheme='official', target_resolution=(224, 224), train_loader='standard', train_transform='image_base', uniform_over_groups=False, use_wandb=False, val_metric='acc_all', val_metric_decreasing=False, weight_decay=0.0001)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "config.optimizer_kwargs = {}" + "config#.optimizer_kwargs = {}" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Dataset: encode-tfbs\n", + "Dataset: bdd100k\n", "Algorithm: ERM\n", "Root dir: data\n", "Split scheme: official\n", @@ -264,31 +298,31 @@ "Train loader: standard\n", "Uniform over groups: False\n", "Distinct groups: None\n", - "N groups per batch: 2\n", - "Batch size: 64\n", + "N groups per batch: 4\n", + "Batch size: 32\n", "Eval loader: standard\n", - "Model: leopard\n", + "Model: resnet50\n", "Model kwargs: {'pretrained': False}\n", - "Train transform: None\n", - "Eval transform: None\n", - "Target resolution: None\n", + "Train transform: image_base\n", + "Eval transform: image_base\n", + "Target resolution: (224, 224)\n", "Resize scale: None\n", "Max token length: None\n", "Loss function: multitask_bce\n", - "Groupby fields: ['celltype']\n", + "Groupby fields: None\n", "Group dro step size: None\n", "Coral penalty weight: None\n", "Irm lambda: None\n", "Irm penalty anneal iters: None\n", - "Algo log metric: multitask_avgprec\n", - "Val metric: acc_avg\n", + "Algo log metric: multitask_accuracy\n", + "Val metric: acc_all\n", "Val metric decreasing: False\n", - "N epochs: 5\n", - "Optimizer: Adam\n", + "N epochs: 10\n", + "Optimizer: SGD\n", "Lr: 0.001\n", - "Weight decay: 0.01\n", + "Weight decay: 0.0001\n", "Max grad norm: None\n", - "Optimizer kwargs: {}\n", + "Optimizer kwargs: {'momentum': 0.9}\n", "Scheduler: None\n", "Scheduler kwargs: {}\n", "Scheduler metric split: val\n", @@ -304,14 +338,11 @@ "Save step: None\n", "Save best: True\n", "Save last: True\n", - "No group logging: False\n", + "No group logging: True\n", "Use wandb: False\n", "Progress bar: False\n", "Resume: False\n", - "\n", - "chr3 2.9614717960357666\n", - "chr2 6.587897777557373\n", - "chr1 10.29332971572876\n" + "\n" ] } ], @@ -359,30 +390,6 @@ " dataset=full_dataset)" ] }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import copy\n", - "full_dataset_camelyon17 = copy.deepcopy(full_dataset)\n", - "\n", - "# supported.datasets[config_encode.dataset]\n", - "# print(config_camelyon.train_transform, config_encode.train_transform)\n" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -392,11 +399,12 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": { "collapsed": true, "jupyter": { - "outputs_hidden": true + "outputs_hidden": true, + "source_hidden": true } }, "outputs": [ @@ -404,36 +412,33 @@ "name": "stdout", "output_type": "stream", "text": [ - "chr3 3.0055365562438965\n", - "chr4 5.905960321426392\n", - "chr5 8.651455879211426\n", - "chr6 11.250766038894653\n", - "chr7 13.660939931869507\n", - "chr10 15.713522672653198\n", - "chr12 17.740623474121094\n", - "chr13 19.478207111358643\n", - "chr14 21.088634252548218\n", - "chr15 22.625713348388672\n", - "chr16 23.987269639968872\n", - "chr17 25.21428894996643\n", - "chr18 26.394341230392456\n", - "chr19 27.28497076034546\n", - "chr20 28.235496282577515\n", - "chr22 28.999913692474365\n", - "chrX 31.338406085968018\n", - "chr2 35.00527381896973\n", - "chr9 37.12277841567993\n", - "chr11 39.157737016677856\n", - "chr1 42.89226841926575\n", - "chr8 45.092690229415894\n", - "chr21 45.81230306625366\n", - "H1-hESC 45.81402635574341\n", - "HCT116 45.814292192459106\n", - "HeLa-S3 45.814526081085205\n", - "HepG2 45.814810276031494\n", - "K562 45.815062522888184\n", - "A549 45.81636619567871\n", - "GM12878 45.81674289703369\n" + "chr3 3.0039219856262207\n", + "chr4 5.89985990524292\n", + "chr5 8.640583038330078\n", + "chr6 11.237342596054077\n", + "chr7 13.666043519973755\n", + "chr10 15.858035326004028\n", + "chr12 17.94972252845764\n", + "chr13 19.689449071884155\n", + "chr14 21.30842876434326\n", + "chr15 22.856398582458496\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0m_seq_bp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mchrom\u001b[0m \u001b[0;32min\u001b[0m \u001b[0m_all_chroms\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 58\u001b[0;31m \u001b[0m_seq_bp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mchrom\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mseq_arr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mchrom\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 59\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mchrom\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mitime\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/numpy/lib/npyio.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 252\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmagic\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mformat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mMAGIC_PREFIX\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 253\u001b[0m \u001b[0mbytes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzip\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 254\u001b[0;31m return format.read_array(bytes,\n\u001b[0m\u001b[1;32m 255\u001b[0m \u001b[0mallow_pickle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mallow_pickle\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 256\u001b[0m pickle_kwargs=self.pickle_kwargs)\n", + "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/numpy/lib/format.py\u001b[0m in \u001b[0;36mread_array\u001b[0;34m(fp, allow_pickle, pickle_kwargs)\u001b[0m\n\u001b[1;32m 773\u001b[0m \u001b[0mread_count\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmax_read_count\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcount\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 774\u001b[0m \u001b[0mread_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mread_count\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitemsize\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 775\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_read_bytes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mread_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"array data\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 776\u001b[0m array[i:i+read_count] = numpy.frombuffer(data, dtype=dtype,\n\u001b[1;32m 777\u001b[0m count=read_count)\n", + "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/numpy/lib/format.py\u001b[0m in \u001b[0;36m_read_bytes\u001b[0;34m(fp, size, error_template)\u001b[0m\n\u001b[1;32m 902\u001b[0m \u001b[0;31m# done about that. note that regular files can't be non-blocking\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 903\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 904\u001b[0;31m \u001b[0mr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msize\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 905\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mr\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 906\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0msize\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/zipfile.py\u001b[0m in \u001b[0;36mread\u001b[0;34m(self, n)\u001b[0m\n\u001b[1;32m 938\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_offset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 939\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_eof\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 940\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_read1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 941\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 942\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_readbuffer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/zipfile.py\u001b[0m in \u001b[0;36m_read1\u001b[0;34m(self, n)\u001b[0m\n\u001b[1;32m 1028\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_left\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1029\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_eof\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1030\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_update_crc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1031\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1032\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/zipfile.py\u001b[0m in \u001b[0;36m_update_crc\u001b[0;34m(self, newdata)\u001b[0m\n\u001b[1;32m 953\u001b[0m \u001b[0;31m# No need to compute the CRC if we don't have a reference value\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 954\u001b[0m \u001b[0;32mreturn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 955\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_running_crc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcrc32\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnewdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_running_crc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 956\u001b[0m \u001b[0;31m# Check the CRC if we're at the end of the file\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 957\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_eof\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_running_crc\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_expected_crc\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], @@ -522,8 +527,12 @@ }, { "cell_type": "code", - "execution_count": 7, - "metadata": {}, + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + } + }, "outputs": [], "source": [ "train_regions_mask = np.isin(_metadata_df['chr'], _train_chroms)\n", @@ -566,7 +575,11 @@ { "cell_type": "code", "execution_count": 325, - "metadata": {}, + "metadata": { + "jupyter": { + "source_hidden": true + } + }, "outputs": [], "source": [ "def get_random_label_vec(\n", @@ -609,7 +622,11 @@ { "cell_type": "code", "execution_count": 24, - "metadata": {}, + "metadata": { + "jupyter": { + "source_hidden": true + } + }, "outputs": [], "source": [ "import os, time\n", @@ -850,7 +867,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -858,38 +875,12 @@ "output_type": "stream", "text": [ "Train data...\n", - " celltype = H1-hESC: n = 5314\n", - " celltype = HCT116: n = 4759\n", - " celltype = HeLa-S3: n = 4635\n", - " celltype = HepG2: n = 4459\n", - " celltype = K562: n = 5169\n", - " celltype = A549: n = 0\n", - " celltype = GM12878: n = 0\n", - "Validation (ID) data...\n", - " celltype = H1-hESC: n = 6872\n", - " celltype = HCT116: n = 6315\n", - " celltype = HeLa-S3: n = 4219\n", - " celltype = HepG2: n = 8356\n", - " celltype = K562: n = 6538\n", - " celltype = A549: n = 0\n", - " celltype = GM12878: n = 0\n", + " n = 64993\n", + "Validation data...\n", + " n = 4860\n", "Test data...\n", - " celltype = H1-hESC: n = 0\n", - " celltype = HCT116: n = 0\n", - " celltype = HeLa-S3: n = 0\n", - " celltype = HepG2: n = 0\n", - " celltype = K562: n = 0\n", - " celltype = A549: n = 0\n", - " celltype = GM12878: n = 4487\n", - "Validation (OOD) data...\n", - " celltype = H1-hESC: n = 0\n", - " celltype = HCT116: n = 0\n", - " celltype = HeLa-S3: n = 0\n", - " celltype = HepG2: n = 0\n", - " celltype = K562: n = 0\n", - " celltype = A549: n = 6728\n", - " celltype = GM12878: n = 0\n", - "Dout: 128\n" + " n = 4742\n", + "Dout: 9\n" ] } ], @@ -969,7 +960,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -979,16 +970,6 @@ "# x = torch.transpose(x, 1, 2)" ] }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "d = algorithm.process_batch(batch)\n", - "# algorithm.loss.compute" - ] - }, { "cell_type": "code", "execution_count": 9, @@ -997,7 +978,7 @@ { "data": { "text/plain": [ - "tensor(0.7212, device='cuda:0', grad_fn=)" + "tensor(0.8208, device='cuda:0', grad_fn=)" ] }, "execution_count": 9, @@ -1006,6 +987,8 @@ } ], "source": [ + "d = algorithm.process_batch(batch)\n", + "\n", "a = algorithm.loss.compute(d['y_pred'], d['y_true'], return_dict=False)\n", "a" ] @@ -1017,141 +1000,8 @@ "outputs": [ { "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
chrstartstopcelltypesplit
39413chr21000320010009600A5493
39414chr2100032000100038400A5493
39415chr2100102400100108800A5493
39416chr2100172800100179200A5493
39417chr2100230400100236800A5493
..................
495287chr3999680010003200K5620
495288chr39997440099980800K5620
495289chr39998080099987200K5620
495290chr39998720099993600K5620
495291chr399993600100000000K5620
\n", - "

67851 rows × 5 columns

\n", - "
" - ], "text/plain": [ - " chr start stop celltype split\n", - "39413 chr2 10003200 10009600 A549 3\n", - "39414 chr2 100032000 100038400 A549 3\n", - "39415 chr2 100102400 100108800 A549 3\n", - "39416 chr2 100172800 100179200 A549 3\n", - "39417 chr2 100230400 100236800 A549 3\n", - "... ... ... ... ... ...\n", - "495287 chr3 9996800 10003200 K562 0\n", - "495288 chr3 99974400 99980800 K562 0\n", - "495289 chr3 99980800 99987200 K562 0\n", - "495290 chr3 99987200 99993600 K562 0\n", - "495291 chr3 99993600 100000000 K562 0\n", - "\n", - "[67851 rows x 5 columns]" + "" ] }, "execution_count": 10, @@ -1161,7 +1011,7 @@ ], "source": [ "#np.unique(full_dataset._metadata_df['split'], return_counts=True)\n", - "full_dataset._metadata_df" + "full_dataset" ] }, { @@ -1170,38 +1020,20 @@ "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "(array([0. , 0.5, 1. ], dtype=float32), array([7422683, 1007200, 255045]))" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.unique(full_dataset.y_array, return_counts=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "0.8546625832706961" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" + "ename": "NameError", + "evalue": "name 'importlib' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m#import importlib\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mimportlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mNameError\u001b[0m: name 'importlib' is not defined" + ] } ], "source": [ - "7422683/8684928" + "#import importlib\n", + "importlib.reload(train)" ] }, { @@ -1223,27 +1055,25 @@ "\n", "Epoch [0]:\n", "\n", - "Train:\n", - "torch.Size([8192]) torch.Size([8192]) torch.Size([64, 128]) torch.Size([64, 128])\n", - "torch.Size([]) torch.Size([8192]) torch.Size([64, 128]) torch.Size([64, 128])\n" + "Train:\n" ] }, { - "ename": "AssertionError", + "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m train(\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdatasets\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0mbest_val_metric\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m train(\n\u001b[0m\u001b[1;32m 26\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdatasets\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/wilds/examples/train.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(algorithm, datasets, general_logger, config, epoch_offset, best_val_metric)\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0;31m# First run training\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 90\u001b[0;31m \u001b[0mrun_epoch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgeneral_logger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 91\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;31m# Then run val\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/wilds/examples/train.py\u001b[0m in \u001b[0;36mrun_epoch\u001b[0;34m(algorithm, dataset, general_logger, epoch, config, train)\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miterator\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 43\u001b[0;31m \u001b[0mbatch_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 44\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0mbatch_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/wilds/examples/algorithms/single_model_algorithm.py\u001b[0m in \u001b[0;36mupdate\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 103\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_update\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[0;31m# log results\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 105\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_log\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 106\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msanitize_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/examples/algorithms/group_algorithm.py\u001b[0m in \u001b[0;36mupdate_log\u001b[0;34m(self, results)\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mm\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogged_metrics\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_group_logging\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m group_metrics, group_counts, worst_group_metric = m.compute_group_wise(\n\u001b[0m\u001b[1;32m 50\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'y_pred'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'y_true'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/wilds/common/metrics/metric.py\u001b[0m in \u001b[0;36mcompute_group_wise\u001b[0;34m(self, y_pred, y_true, g, n_groups, return_dict)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mDictionary\u001b[0m \u001b[0mof\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \"\"\"\n\u001b[0;32m--> 114\u001b[0;31m \u001b[0mgroup_metrics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgroup_counts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mworst_group_metric\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compute_group_wise\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_groups\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 115\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/wilds/common/metrics/metric.py\u001b[0m in \u001b[0;36m_compute_group_wise\u001b[0;34m(self, y_pred, y_true, g, n_groups)\u001b[0m\n\u001b[1;32m 234\u001b[0m \u001b[0mflattened_g\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mindices\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 235\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mflattened_metrics\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mflattened_g\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 236\u001b[0;31m \u001b[0mgroup_metrics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgroup_counts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mavg_over_groups\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mflattened_metrics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mflattened_g\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_groups\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 237\u001b[0m \u001b[0mworst_group_metric\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mworst\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgroup_metrics\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mgroup_counts\u001b[0m\u001b[0;34m>\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 238\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mgroup_metrics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgroup_counts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mworst_group_metric\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/wilds/common/utils.py\u001b[0m in \u001b[0;36mavg_over_groups\u001b[0;34m(v, g, n_groups)\u001b[0m\n\u001b[1;32m 84\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[0mdevice\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 86\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 87\u001b[0m \u001b[0mgroup_count\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_counts\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_groups\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[0mgroup_avgs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch_scatter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscatter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msrc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindex\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mn_groups\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'mean'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mAssertionError\u001b[0m: " + "\u001b[0;32m~/wilds/examples/algorithms/group_algorithm.py\u001b[0m in \u001b[0;36mupdate_log\u001b[0;34m(self, results)\u001b[0m\n\u001b[1;32m 54\u001b[0m return_dict=False)\n\u001b[1;32m 55\u001b[0m \u001b[0mbatch_log\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34mf'{self.group_prefix}{m.name}'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgroup_metrics\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 56\u001b[0;31m batch_log[m.agg_metric_field] = m.compute(\n\u001b[0m\u001b[1;32m 57\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'y_pred'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'y_true'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/wilds/common/metrics/metric.py\u001b[0m in \u001b[0;36mcompute\u001b[0;34m(self, y_pred, y_true, return_dict)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0magg_metric\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0.\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0my_true\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 87\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 88\u001b[0;31m \u001b[0magg_metric\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 89\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 90\u001b[0m results = {\n", + "\u001b[0;32m~/wilds/wilds/common/metrics/metric.py\u001b[0m in \u001b[0;36m_compute\u001b[0;34m(self, y_pred, y_true)\u001b[0m\n\u001b[1;32m 224\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 225\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_compute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 226\u001b[0;31m \u001b[0mflattened_metrics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_flattened\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 227\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mflattened_metrics\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 228\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0.\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0my_true\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/wilds/common/metrics/metric.py\u001b[0m in \u001b[0;36mcompute_flattened\u001b[0;34m(self, y_pred, y_true, return_dict)\u001b[0m\n\u001b[1;32m 240\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_flattened\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 241\u001b[0m \u001b[0mis_labeled\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m~\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misnan\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_true\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 242\u001b[0;31m \u001b[0mbatch_idx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwhere\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mis_labeled\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 243\u001b[0m \u001b[0mflattened_y_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mis_labeled\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 244\u001b[0m \u001b[0mflattened_y_true\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mis_labeled\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], @@ -1271,8 +1101,7 @@ " if resume_success == False:\n", " epoch_offset=0\n", " best_val_metric=None\n", - "\n", - "\n", + " \n", " train(\n", " algorithm=algorithm,\n", " datasets=datasets,\n", @@ -1324,6 +1153,20 @@ "outputs": [], "source": [] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py index 0d84f1ad..9a0d4de1 100644 --- a/wilds/common/metrics/all_metrics.py +++ b/wilds/common/metrics/all_metrics.py @@ -74,17 +74,20 @@ def __init__(self, prediction_fn=logits_to_binary_pred, name=None, average='macr def _compute_flattened(self, flattened_y_pred, flattened_y_true): if self.prediction_fn is not None: flattened_y_pred = self.prediction_fn(flattened_y_pred) + ytr = np.array(flattened_y_true.squeeze().detach().cpu().numpy() > 0) + ypr = flattened_y_pred.squeeze().detach().cpu().numpy() score = sklearn.metrics.average_precision_score( - np.array(flattened_y_true.squeeze().detach().cpu().numpy() > 0), - flattened_y_pred.squeeze().detach().cpu().numpy(), + ytr, + ypr, average=self.average ) - return torch.tensor(score).to(flattened_y_pred.device) + to_ret = torch.tensor(score).to(flattened_y_pred.device) + print("why ", ytr, ytr.shape, ypr, ypr.shape, score, to_ret) + return to_ret def worst(self, metrics): return minimum(metrics) - class Recall(Metric): def __init__(self, prediction_fn=logits_to_pred, name=None, average='binary'): self.prediction_fn = prediction_fn diff --git a/wilds/common/metrics/metric.py b/wilds/common/metrics/metric.py index 281696d8..207dacf6 100644 --- a/wilds/common/metrics/metric.py +++ b/wilds/common/metrics/metric.py @@ -232,7 +232,7 @@ def _compute(self, y_pred, y_true): def _compute_group_wise(self, y_pred, y_true, g, n_groups): flattened_metrics, indices = self.compute_flattened(y_pred, y_true, return_dict=False) flattened_g = g[indices] - print(flattened_metrics.shape, flattened_g.shape, y_pred.shape, y_true.shape) + print(flattened_metrics.shape, flattened_g.shape, (indices > 0).sum(), y_pred.shape, y_true.shape) group_metrics, group_counts = avg_over_groups(flattened_metrics, flattened_g, n_groups) worst_group_metric = self.worst(group_metrics[group_counts>0]) return group_metrics, group_counts, worst_group_metric From 9646f588aceb1d4bb346b72c652fdd21c46366f6 Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Fri, 19 Mar 2021 06:22:45 -0700 Subject: [PATCH 039/244] Added evaluation script --- .gitignore | 1 + examples/configs/datasets.py | 15 ++ examples/evaluate.py | 262 +++++++++++++++++++++++++++++++++++ mypy.ini | 8 ++ requirements.dev.txt | 3 + 5 files changed, 289 insertions(+) create mode 100644 examples/evaluate.py create mode 100644 mypy.ini create mode 100644 requirements.dev.txt diff --git a/.gitignore b/.gitignore index ac33582a..1d3b5479 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ __pycache__ build dist +venv wilds.egg-info diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index cd2d1d6f..fedbbde9 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -284,6 +284,21 @@ }, } +####################################### +### List of the main WILDS datasets ### +####################################### + +main_datasets = [ + 'amazon', + 'camelyon17', + 'civilcomments', + 'fmow', + 'iwildcam', + 'ogb-molpcba', + 'py150', + 'poverty', +] + ########################################## ### Split-specific defaults for Amazon ### ########################################## diff --git a/examples/evaluate.py b/examples/evaluate.py new file mode 100644 index 00000000..b8f0ff31 --- /dev/null +++ b/examples/evaluate.py @@ -0,0 +1,262 @@ +import argparse +import json +import os +import sys +import urllib.request +from ast import literal_eval +from typing import Any, Dict, List, Union +from urllib.parse import urlparse + +import numpy as np +import torch + +from configs.datasets import main_datasets +from wilds import get_dataset +from wilds.datasets.wilds_dataset import WILDSDataset, WILDSSubset + + +""" +Evaluate models and predictions for WILDS datasets. + +Usage: + + python examples/evaluate.py + python examples/evaluate.py --dataset + +""" + + +def evaluate_all(path: str, output_path: str, dataset_path: str): + """ + Evaluate for all the WILDS datasets. + + Parameters: + path (str): Path to the directory with models or predictions. Can be a URL + output_path (str): Output directory + dataset_path (str): Path to the dataset directory + """ + all_results: Dict[str, Dict[str, Dict[str, float]]] = dict() + for dataset in main_datasets: + all_results[dataset] = evaluate_multiple_replicates( + dataset, path, output_path, dataset_path + ) + + # Write out aggregated results to output file + print(f"Writing complete results to {output_path}...") + with open(os.path.join(output_path, f"all_results.json"), "w") as f: + json.dump(all_results, f, indent=4) + + +def evaluate_multiple_replicates( + dataset_name: str, path: str, output_path: str, dataset_path: str +) -> Dict[str, Dict[str, float]]: + """ + Evaluate across multiple replicates. + + Parameters: + dataset_name (str): Name of the dataset. See datasets.py for the complete list of datasets. + path (str): Path to the directory with models or predictions. Can be a URL. + output_path (str): Output directory + dataset_path (str): Path to the dataset directory + + Returns: + Metrics as a dictionary with metrics as the keys and metric values as the values + """ + + def get_splits(dataset_name: str) -> List[str]: + if dataset_name in ["amazon", "fmow", "iwildcam", "poverty", "py150"]: + return ["id_val", "id_test", "val", "test"] + elif dataset_name == "camelyon17": + return ["id_val", "val", "test"] + elif dataset_name in ["civilcomments", "ogb-molpcba"]: + return ["val", "test"] + else: + raise ValueError(f"Invalid dataset: {dataset_name}") + + def get_replicates(dataset_name: str) -> List[Union[str, int]]: + if dataset_name == "camelyon17": + return list(range(0, 10)) + elif dataset_name == "poverty": + return ["A", "B", "C", "D", "E"] + else: + return list(range(0, 3)) + + def get_best_prediction_filename( + dataset_name: str, split: str, replicate: Union[str, int] + ) -> str: + if dataset_name == "poverty": + return f"{dataset_name}_split:{split}_fold:{replicate}_epoch:best_pred.csv" + else: + return f"{dataset_name}_split:{split}_seed:{replicate}_epoch:best_pred.csv" + + def get_metrics(dataset_name: str) -> List[str]: + if "amazon" == dataset_name: + return ["10th_percentile_acc", "acc_avg"] + elif "camelyon17" == dataset_name: + return ["acc_avg"] + elif "civilcomments" == dataset_name: + return ["acc_wg", "acc_avg"] + elif "fmow" == dataset_name: + return ["acc_worst_region", "acc_avg"] + elif "iwildcam" == dataset_name: + return ["F1-macro_all", "acc_avg"] + elif "ogb-molpcba" == dataset_name: + return ["ap"] + elif "poverty" == dataset_name: + return ["r_wg", "r_all"] + elif "py150" == dataset_name: + return ["acc", "Acc (Overall)"] + else: + raise ValueError(f"Invalid dataset: {dataset_name}") + + replicates_results: Dict[str, Dict[str, List[float]]] = dict() + splits: List[str] = get_splits(dataset_name) + replicates: List[Union[str, int]] = get_replicates(dataset_name) + metrics: List[str] = get_metrics(dataset_name) + + # Store the results for each replicate + for split in splits: + replicates_results[split] = {} + for metric in metrics: + replicates_results[split][metric] = [] + + for replicate in replicates: + # TODO: do I need to set seed here? -Tony + # set_seed_for_dataset(dataset_name, seed=replicate) + + predictions_file = get_best_prediction_filename( + dataset_name, split, replicate + ) + print( + f"Processing split={split}, replicate={replicate}, predictions_file={predictions_file}..." + ) + full_path = os.path.join(path, predictions_file) + predicted_labels: List[Any] = get_predictions(full_path) + predicted_labels_tensor: torch.Tensor = torch.from_numpy( + np.array(predicted_labels) + ) + metric_results: Dict[str, float] = evaluate( + dataset_name, split, predicted_labels_tensor, dataset_path + ) + for metric in metrics: + replicates_results[split][metric].append(metric_results[metric]) + + aggregated_results: Dict[str, Dict[str, float]] = dict() + + # Aggregate results of replicates + for split in splits: + aggregated_results[split] = {} + for metric in metrics: + replicates_metric_values: List[float] = replicates_results[split][metric] + aggregated_results[split][f"{metric}_std"] = np.std( + replicates_metric_values, ddof=1 + ) + aggregated_results[split][metric] = np.mean(replicates_metric_values) + + # Write out aggregated results to output file + print(f"Writing aggregated results for {dataset_name} to {output_path}...") + with open(os.path.join(output_path, f"{dataset_name}_results.json"), "w") as f: + json.dump(aggregated_results, f, indent=4) + + return aggregated_results + + +def evaluate( + dataset_name: str, split: str, predicted_labels: torch.Tensor, dataset_path: str +) -> Dict[str, float]: + """ + Evaluate the given predictions and return the appropriate metrics. + + Parameters: + dataset_name (str): Name of the dataset. + predicted_labels (torch.Tensor): Predictions + dataset_path (str): Path to the dataset directory + + Returns: + Metrics as a dictionary with metrics as the keys and metric values as the values + """ + # Dataset will only be downloaded if it does not exist + dataset: WILDSDataset = get_dataset( + dataset=dataset_name, root_dir=dataset_path, download=True + ) + subset: WILDSSubset = dataset.get_subset(split) + true_labels: torch.Tensor = subset.y_array + metadata: torch.Tensor = subset.metadata_array + # Attempt to resize predicted_labels tensor to match true_labels tensor's shape + predicted_labels.resize_(true_labels.shape) + return dataset.eval(predicted_labels, true_labels, metadata)[0] + + +def get_predictions(path: str) -> List[Any]: + """ + Extract out the predictions from the file at path. + + Parameters: + path (str): Path to the file that has the predicted labels. Can be a URL. + + Return: + List of predictions. + """ + if is_path_url(path): + data = urllib.request.urlopen(path) + else: + file = open(path, mode="r") + data = file.readlines() + file.close() + + predicted_labels = [literal_eval(line.rstrip()) for line in data if line.rstrip()] + return predicted_labels + + +def is_path_url(path: str) -> bool: + """ + Returns True if the path is a URL. + """ + try: + result = urlparse(path) + return all([result.scheme, result.netloc, result.path]) + except: + return False + + +def main(): + if args.dataset: + evaluate_multiple_replicates( + args.dataset, args.path, args.output_path, args.dataset_path + ) + else: + print("A dataset was not specified. Evaluating for all WILDS datasets...") + evaluate_all(args.path, args.output_path, args.dataset_path) + print("\nDone.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Evaluate models and predictions for WILDS datasets." + ) + parser.add_argument( + "path", + type=str, + help="Path to prediction CSV files.", + ) + parser.add_argument( + "output_path", + type=str, + help="Path to output directory.", + ) + parser.add_argument( + "--dataset", + type=str, + choices=main_datasets, + help="WILDS dataset to evaluate for.", + ) + parser.add_argument( + "--dataset-path", + type=str, + default="data", + help="Path to dataset. Defaults to `data` if not specified.", + ) + + # Parse args and run this script + args = parser.parse_args() + main() diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..de21122c --- /dev/null +++ b/mypy.ini @@ -0,0 +1,8 @@ +# Mypy is a static type checker for Python 3 and Python 2.7. If you sprinkle your code with type annotations, +# mypy can type check your code and find common bugs. As mypy is a static analyzer, or a lint-like tool, the type +# annotations are just hints for mypy and don’t interfere when running your program. You run your program with a +# standard Python interpreter, and the annotations are treated effectively as comments. +# See https://mypy.readthedocs.io/en/stable/index.html for more information. + +[mypy] +ignore_missing_imports = True diff --git a/requirements.dev.txt b/requirements.dev.txt new file mode 100644 index 00000000..d28b7f55 --- /dev/null +++ b/requirements.dev.txt @@ -0,0 +1,3 @@ +black==20.8b1 +codalab>=0.5.40 +mypy==0.782 From ea1303d9db39f6a85152188a1f813a90da1c547b Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Fri, 19 Mar 2021 06:24:02 -0700 Subject: [PATCH 040/244] cleanup --- examples/evaluate.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/evaluate.py b/examples/evaluate.py index b8f0ff31..a2f08b69 100644 --- a/examples/evaluate.py +++ b/examples/evaluate.py @@ -121,9 +121,6 @@ def get_metrics(dataset_name: str) -> List[str]: replicates_results[split][metric] = [] for replicate in replicates: - # TODO: do I need to set seed here? -Tony - # set_seed_for_dataset(dataset_name, seed=replicate) - predictions_file = get_best_prediction_filename( dataset_name, split, replicate ) From 12da325c4e20633abc6a185a12ce85d5d68b9d68 Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Fri, 19 Mar 2021 06:27:28 -0700 Subject: [PATCH 041/244] Added comments to new dependencies --- requirements.dev.txt | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/requirements.dev.txt b/requirements.dev.txt index d28b7f55..683e5b5a 100644 --- a/requirements.dev.txt +++ b/requirements.dev.txt @@ -1,3 +1,2 @@ -black==20.8b1 -codalab>=0.5.40 -mypy==0.782 +black==20.8b1 # Python code formatter +mypy==0.782 # Python static type checker From e4073ed649054468341faba705835a118397d158 Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Fri, 19 Mar 2021 06:30:16 -0700 Subject: [PATCH 042/244] Cleanup comments --- examples/evaluate.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/evaluate.py b/examples/evaluate.py index a2f08b69..edf0cfc8 100644 --- a/examples/evaluate.py +++ b/examples/evaluate.py @@ -16,12 +16,12 @@ """ -Evaluate models and predictions for WILDS datasets. +Evaluate predictions for WILDS datasets. Usage: - python examples/evaluate.py - python examples/evaluate.py --dataset + python examples/evaluate.py + python examples/evaluate.py --dataset """ @@ -31,7 +31,7 @@ def evaluate_all(path: str, output_path: str, dataset_path: str): Evaluate for all the WILDS datasets. Parameters: - path (str): Path to the directory with models or predictions. Can be a URL + path (str): Path to the directory with predictions. Can be a URL output_path (str): Output directory dataset_path (str): Path to the dataset directory """ @@ -55,7 +55,7 @@ def evaluate_multiple_replicates( Parameters: dataset_name (str): Name of the dataset. See datasets.py for the complete list of datasets. - path (str): Path to the directory with models or predictions. Can be a URL. + path (str): Path to the directory with predictions. Can be a URL. output_path (str): Output directory dataset_path (str): Path to the dataset directory @@ -229,7 +229,7 @@ def main(): if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Evaluate models and predictions for WILDS datasets." + description="Evaluate predictions for WILDS datasets." ) parser.add_argument( "path", From b0f392f3b63f91dc333b4c03531dde876962cfa2 Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Fri, 19 Mar 2021 06:59:21 -0700 Subject: [PATCH 043/244] Get splits from WILDSDataset --- examples/evaluate.py | 29 ++++++++++------------------- 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/examples/evaluate.py b/examples/evaluate.py index edf0cfc8..647ec16c 100644 --- a/examples/evaluate.py +++ b/examples/evaluate.py @@ -43,7 +43,7 @@ def evaluate_all(path: str, output_path: str, dataset_path: str): # Write out aggregated results to output file print(f"Writing complete results to {output_path}...") - with open(os.path.join(output_path, f"all_results.json"), "w") as f: + with open(os.path.join(output_path, "all_results.json"), "w") as f: json.dump(all_results, f, indent=4) @@ -63,16 +63,6 @@ def evaluate_multiple_replicates( Metrics as a dictionary with metrics as the keys and metric values as the values """ - def get_splits(dataset_name: str) -> List[str]: - if dataset_name in ["amazon", "fmow", "iwildcam", "poverty", "py150"]: - return ["id_val", "id_test", "val", "test"] - elif dataset_name == "camelyon17": - return ["id_val", "val", "test"] - elif dataset_name in ["civilcomments", "ogb-molpcba"]: - return ["val", "test"] - else: - raise ValueError(f"Invalid dataset: {dataset_name}") - def get_replicates(dataset_name: str) -> List[Union[str, int]]: if dataset_name == "camelyon17": return list(range(0, 10)) @@ -109,8 +99,12 @@ def get_metrics(dataset_name: str) -> List[str]: else: raise ValueError(f"Invalid dataset: {dataset_name}") + # Dataset will only be downloaded if it does not exist + wilds_dataset: WILDSDataset = get_dataset( + dataset=dataset_name, root_dir=dataset_path, download=True + ) + splits: List[str] = wilds_dataset.split_dict.keys() replicates_results: Dict[str, Dict[str, List[float]]] = dict() - splits: List[str] = get_splits(dataset_name) replicates: List[Union[str, int]] = get_replicates(dataset_name) metrics: List[str] = get_metrics(dataset_name) @@ -133,7 +127,7 @@ def get_metrics(dataset_name: str) -> List[str]: np.array(predicted_labels) ) metric_results: Dict[str, float] = evaluate( - dataset_name, split, predicted_labels_tensor, dataset_path + wilds_dataset, split, predicted_labels_tensor ) for metric in metrics: replicates_results[split][metric].append(metric_results[metric]) @@ -159,23 +153,20 @@ def get_metrics(dataset_name: str) -> List[str]: def evaluate( - dataset_name: str, split: str, predicted_labels: torch.Tensor, dataset_path: str + dataset: WILDSDataset, split: str, predicted_labels: torch.Tensor ) -> Dict[str, float]: """ Evaluate the given predictions and return the appropriate metrics. Parameters: - dataset_name (str): Name of the dataset. + dataset (WILDSDataset): A WILDS Dataset + split (str): split we are evaluating on predicted_labels (torch.Tensor): Predictions - dataset_path (str): Path to the dataset directory Returns: Metrics as a dictionary with metrics as the keys and metric values as the values """ # Dataset will only be downloaded if it does not exist - dataset: WILDSDataset = get_dataset( - dataset=dataset_name, root_dir=dataset_path, download=True - ) subset: WILDSSubset = dataset.get_subset(split) true_labels: torch.Tensor = subset.y_array metadata: torch.Tensor = subset.metadata_array From 1dfc58f7868a00439a0d9904fe0d3edfbde722b1 Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Fri, 19 Mar 2021 07:15:07 -0700 Subject: [PATCH 044/244] Get rid of train split --- examples/evaluate.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/evaluate.py b/examples/evaluate.py index 647ec16c..6cefda37 100644 --- a/examples/evaluate.py +++ b/examples/evaluate.py @@ -104,6 +104,9 @@ def get_metrics(dataset_name: str) -> List[str]: dataset=dataset_name, root_dir=dataset_path, download=True ) splits: List[str] = wilds_dataset.split_dict.keys() + if "train" in splits: + splits.remove("train") + replicates_results: Dict[str, Dict[str, List[float]]] = dict() replicates: List[Union[str, int]] = get_replicates(dataset_name) metrics: List[str] = get_metrics(dataset_name) From 1a4444d1536cb26053b7d4eb9c344db2c3e68d73 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Fri, 19 Mar 2021 11:33:00 -0700 Subject: [PATCH 045/244] x,y transform --- examples/configs/datasets.py | 4 ++-- examples/configs/model.py | 3 ++- examples/models/detr/detr.py | 13 ++++++++----- examples/transforms.py | 18 ++++++++++++++---- wilds/common/metrics/all_metrics.py | 17 +++++++---------- wilds/datasets/wilds_dataset.py | 2 +- 6 files changed, 34 insertions(+), 23 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index a201e4bd..a8102723 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -301,9 +301,9 @@ 'optimizer_kwargs': {}, 'scheduler': None, 'batch_size': 4, - 'lr': 1e-4, + 'lr': 1e-5, 'weight_decay': 1e-4, - 'n_epochs': 10, + 'n_epochs': 50, 'process_outputs_function': 'remove_detr_aux_outputs', 'loader_kwargs': { 'num_workers': 1, diff --git a/examples/configs/model.py b/examples/configs/model.py index d2cb1cd4..79b29841 100644 --- a/examples/configs/model.py +++ b/examples/configs/model.py @@ -63,7 +63,8 @@ 'dice_loss_coef': 1, 'bbox_loss_coef': 5, 'giou_loss_coef': 2, - 'eos_coef': 0.1, + # 'eos_coef': 0.1, + 'eos_coef': 0.5, } } } diff --git a/examples/models/detr/detr.py b/examples/models/detr/detr.py index b734b4a7..061c5fcd 100644 --- a/examples/models/detr/detr.py +++ b/examples/models/detr/detr.py @@ -210,11 +210,12 @@ def forward(self, outputs, targets): indices = self.matcher(outputs_without_aux, targets) # Compute the average number of target boxes accross all nodes, for normalization purposes - num_boxes = sum(len(t["labels"]) for t in targets) - num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) - if is_dist_avail_and_initialized(): - torch.distributed.all_reduce(num_boxes) - num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() + # num_boxes = sum(len(t["labels"]) for t in targets) + # num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) + # if is_dist_avail_and_initialized(): + # torch.distributed.all_reduce(num_boxes) + # num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() + num_boxes = None # Compute all the requested losses total_loss = 0 @@ -239,6 +240,8 @@ def forward(self, outputs, targets): device = outputs['pred_logits'].device elementwise_loss = torch.zeros(len(outputs['pred_logits']), device=device) + # print(f"Losses: class {losses['loss_ce'].detach().cpu().numpy()}, bbox {losses['loss_bbox'].detach().cpu().numpy()}, giou {losses['loss_giou'].detach().cpu().numpy()}") + for k in self.weight_dict: elementwise_loss += self.weight_dict[k] * losses[k] diff --git a/examples/transforms.py b/examples/transforms.py index bafbd42f..bbcd88a4 100644 --- a/examples/transforms.py +++ b/examples/transforms.py @@ -3,6 +3,10 @@ import torch def initialize_transform(transform_name, config, dataset): + """ + Transforms should take in a single (x, y) + and return (transformed_x, transformed_y). + """ if transform_name is None: return None elif transform_name=='bert': @@ -16,6 +20,11 @@ def initialize_transform(transform_name, config, dataset): else: raise ValueError(f"{transform_name} not recognized") +def transform_input_only(input_transform): + def transform(x, y): + return input_transform(x), y + return transform + def initialize_bert_transform(config): assert 'bert' in config.model assert config.max_token_length is not None @@ -41,7 +50,7 @@ def transform(text): dim=2) x = torch.squeeze(x, dim=0) # First shape dim is always 1 return x - return transform + return transform_input_only(transform) def getBertTokenizer(model): if model == 'bert-base-uncased': @@ -65,7 +74,7 @@ def initialize_image_base_transform(config, dataset): transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ] transform = transforms.Compose(transform_steps) - return transform + return transform_input_only(transform) def initialize_image_resize_and_center_crop_transform(config, dataset): """ @@ -84,7 +93,7 @@ def initialize_image_resize_and_center_crop_transform(config, dataset): transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) - return transform + return transform_input_only(transform) def initialize_poverty_train_transform(): transforms_ls = [ @@ -99,5 +108,6 @@ def transform_rgb(img): # bgr to rgb and back to bgr img[:3] = rgb_transform(img[:3][[2,1,0]])[[2,1,0]] return img + transform = transforms.Lambda(lambda x: transform_rgb(x)) - return transform + return transform_input_only(transform) diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py index ddf2baf0..b1d8c021 100644 --- a/wilds/common/metrics/all_metrics.py +++ b/wilds/common/metrics/all_metrics.py @@ -185,7 +185,7 @@ def _compute_element_wise(self, y_pred ,y_true ): batch_results = [] for src_boxes, target_boxes, target_logits in zip( y_true, y_pred['pred_boxes'], y_pred['pred_logits']): - # Here should be prediction_fn ? + # Here should be prediction_fn ? target_scores = F.softmax(target_logits, dim=1)[..., 0] pred_boxes = target_boxes[target_scores > self.score_threshold] @@ -204,7 +204,7 @@ def _accuracy(self, src_boxes,pred_boxes , iou_threshold = 1.): if total_gt > 0 and total_pred > 0: # Define the matcher and distance matrix based on iou - matcher = Matcher(iou_threshold,iou_threshold,allow_low_quality_matches=False) + matcher = Matcher(iou_threshold,iou_threshold,allow_low_quality_matches=False) src_boxes = box_convert(src_boxes , "cxcywh" ,"xyxy") pred_boxes = box_convert(pred_boxes , "cxcywh" ,"xyxy") @@ -213,16 +213,15 @@ def _accuracy(self, src_boxes,pred_boxes , iou_threshold = 1.): match_quality_matrix = box_iou(src_boxes,pred_boxes) results = matcher(match_quality_matrix) - + true_positive = torch.count_nonzero(results.unique() != -1) matched_elements = results[results > -1] - - #in Matcher, a pred element can be matched only twice + + #in Matcher, a pred element can be matched only twice false_positive = torch.count_nonzero(results == -1) + ( len(matched_elements) - len(matched_elements.unique())) false_negative = total_gt - true_positive - - return true_positive / ( true_positive + false_positive + false_negative ) + return true_positive / ( true_positive + false_positive + false_negative ) elif total_gt == 0: if total_pred > 0: @@ -231,10 +230,8 @@ def _accuracy(self, src_boxes,pred_boxes , iou_threshold = 1.): return torch.tensor(1.) elif total_gt > 0 and total_pred == 0: return torch.tensor(0.) - + def worst(self, metrics): return minimum(metrics) - - diff --git a/wilds/datasets/wilds_dataset.py b/wilds/datasets/wilds_dataset.py index 8bf0128d..8812f957 100644 --- a/wilds/datasets/wilds_dataset.py +++ b/wilds/datasets/wilds_dataset.py @@ -453,7 +453,7 @@ def __init__(self, dataset, indices, transform): def __getitem__(self, idx): x, y, metadata = self.dataset[self.indices[idx]] if self.transform is not None: - x = self.transform(x) + x, y = self.transform(x, y) return x, y, metadata def __len__(self): From c14746015831bbb47964f4bbf88123fc810c602d Mon Sep 17 00:00:00 2001 From: aikanor Date: Fri, 19 Mar 2021 20:17:48 -0700 Subject: [PATCH 046/244] integration besides eval 2/ --- examples/configs/supported.py | 4 +- examples/sbox_run_expt.ipynb | 494 +++++++++++----------------- wilds/common/metrics/all_metrics.py | 30 ++ 3 files changed, 223 insertions(+), 305 deletions(-) diff --git a/examples/configs/supported.py b/examples/configs/supported.py index bf7f73cc..c2d0fe5c 100644 --- a/examples/configs/supported.py +++ b/examples/configs/supported.py @@ -16,7 +16,7 @@ from wilds.datasets.yelp_dataset import YelpDataset # metrics from wilds.common.metrics.loss import ElementwiseLoss, Loss, MultiTaskLoss -from wilds.common.metrics.all_metrics import Accuracy, MultiTaskAccuracy, MSE, MultiTaskAveragePrecision +from wilds.common.metrics.all_metrics import Accuracy, MultiTaskAccuracy, MSE, MultiTaskAveragePrecision, MTAveragePrecision datasets = { 'amazon': AmazonDataset, @@ -43,7 +43,7 @@ 'accuracy': Accuracy(), 'mse': MSE(), 'multitask_accuracy': MultiTaskAccuracy(), - 'multitask_avgprec': MultiTaskAveragePrecision(), + 'multitask_avgprec': MTAveragePrecision(), None: None, } diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb index 86525331..2aad102b 100644 --- a/examples/sbox_run_expt.ipynb +++ b/examples/sbox_run_expt.ipynb @@ -11,21 +11,14 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "396.69921875\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:root:The OGB package is out of date. Your version is 1.2.4, while the latest version is 1.3.0.\n" + "47.42578125\n" ] } ], @@ -113,7 +106,7 @@ { "data": { "text/plain": [ - "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" + "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" ] }, "execution_count": 2, @@ -229,7 +222,7 @@ "\n", "config = config_camelyon\n", "config = config_encode\n", - "config = config_bdd100k\n" + "# config = config_bdd100k\n" ] }, { @@ -262,20 +255,9 @@ "cell_type": "code", "execution_count": 5, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Namespace(algo_log_metric='multitask_accuracy', algorithm='ERM', batch_size=32, coral_penalty_weight=None, dataset='bdd100k', dataset_kwargs={}, device=0, distinct_groups=None, download=False, eval_epoch=None, eval_loader='standard', eval_only=False, eval_splits=[], eval_transform='image_base', evaluate_all_splits=True, frac=1.0, group_dro_step_size=None, groupby_fields=None, irm_lambda=None, irm_penalty_anneal_iters=None, loader_kwargs={'num_workers': 1, 'pin_memory': True}, log_dir='./logs', log_every=50, loss_function='multitask_bce', lr=0.001, max_grad_norm=None, max_token_length=None, model='resnet50', model_kwargs={'pretrained': False}, n_epochs=10, n_groups_per_batch=4, no_group_logging=True, optimizer='SGD', optimizer_kwargs={'momentum': 0.9}, progress_bar=False, resize_scale=None, resume=False, root_dir='data', save_best=True, save_last=True, save_step=None, scheduler=None, scheduler_kwargs={}, scheduler_metric_name=None, scheduler_metric_split='val', seed=0, split_scheme='official', target_resolution=(224, 224), train_loader='standard', train_transform='image_base', uniform_over_groups=False, use_wandb=False, val_metric='acc_all', val_metric_decreasing=False, weight_decay=0.0001)" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "config#.optimizer_kwargs = {}" + "config.optimizer_kwargs = {}" ] }, { @@ -287,7 +269,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Dataset: bdd100k\n", + "Dataset: encode-tfbs\n", "Algorithm: ERM\n", "Root dir: data\n", "Split scheme: official\n", @@ -298,31 +280,31 @@ "Train loader: standard\n", "Uniform over groups: False\n", "Distinct groups: None\n", - "N groups per batch: 4\n", - "Batch size: 32\n", + "N groups per batch: 2\n", + "Batch size: 64\n", "Eval loader: standard\n", - "Model: resnet50\n", + "Model: leopard\n", "Model kwargs: {'pretrained': False}\n", - "Train transform: image_base\n", - "Eval transform: image_base\n", - "Target resolution: (224, 224)\n", + "Train transform: None\n", + "Eval transform: None\n", + "Target resolution: None\n", "Resize scale: None\n", "Max token length: None\n", "Loss function: multitask_bce\n", - "Groupby fields: None\n", + "Groupby fields: ['celltype']\n", "Group dro step size: None\n", "Coral penalty weight: None\n", "Irm lambda: None\n", "Irm penalty anneal iters: None\n", - "Algo log metric: multitask_accuracy\n", - "Val metric: acc_all\n", + "Algo log metric: multitask_avgprec\n", + "Val metric: acc_avg\n", "Val metric decreasing: False\n", - "N epochs: 10\n", - "Optimizer: SGD\n", + "N epochs: 5\n", + "Optimizer: Adam\n", "Lr: 0.001\n", - "Weight decay: 0.0001\n", + "Weight decay: 0.01\n", "Max grad norm: None\n", - "Optimizer kwargs: {'momentum': 0.9}\n", + "Optimizer kwargs: {}\n", "Scheduler: None\n", "Scheduler kwargs: {}\n", "Scheduler metric split: val\n", @@ -338,11 +320,14 @@ "Save step: None\n", "Save best: True\n", "Save last: True\n", - "No group logging: True\n", + "No group logging: False\n", "Use wandb: False\n", "Progress bar: False\n", "Resume: False\n", - "\n" + "\n", + "chr3 2.979121685028076\n", + "chr2 6.626891374588013\n", + "chr1 10.355815410614014\n" ] } ], @@ -612,252 +597,6 @@ " return mdf, y_label_vec" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Dataset object (long version)" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": { - "jupyter": { - "source_hidden": true - } - }, - "outputs": [], - "source": [ - "import os, time\n", - "import torch\n", - "import pandas as pd\n", - "import numpy as np\n", - "from wilds.datasets.wilds_dataset import WILDSDataset\n", - "from wilds.common.grouper import CombinatorialGrouper\n", - "from wilds.common.metrics.all_metrics import Accuracy\n", - "\n", - "class EncodeTFBSDataset(WILDSDataset):\n", - " \"\"\"\n", - " ENCODE-DREAM-wilds dataset of transcription factor binding sites. \n", - " This is a subset of the dataset from the ENCODE-DREAM in vivo Transcription Factor Binding Site Prediction Challenge. \n", - " \n", - " Input (x):\n", - " 1000-base-pair regions of sequence with a quantified chromatin accessibility readout.\n", - "\n", - " Label (y):\n", - " y is binary. It is 1 if the central 200bp region is bound by the transcription factor MAX, and 0 otherwise.\n", - "\n", - " Metadata:\n", - " Each sequence is annotated with the celltype of origin (a string) and the chromosome of origin (a string).\n", - " \n", - " Website:\n", - " https://www.synapse.org/#!Synapse:syn6131484\n", - " \"\"\"\n", - "\n", - " def __init__(self, root_dir='data', download=False, split_scheme='official'):\n", - " itime = time.time()\n", - " self._dataset_name = 'encode-tfbs'\n", - " self._version = '1.0'\n", - " self._download_url = 'https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/'\n", - " self._data_dir = self.initialize_data_dir(root_dir, download)\n", - " self._y_size = 128\n", - " # self._n_classes = 2\n", - " \n", - " self._train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX']\n", - " self._val_chroms = ['chr2', 'chr9', 'chr11']\n", - " self._test_chroms = ['chr1', 'chr8', 'chr21']\n", - " self._transcription_factor = 'MAX'\n", - " self._train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", - " self._val_celltype = ['A549']\n", - " self._test_celltype = ['GM12878']\n", - " self._all_chroms = self._train_chroms + self._val_chroms + self._test_chroms\n", - " self._all_celltypes = self._train_celltypes + self._val_celltype + self._test_celltype\n", - " \n", - " self._metadata_map = {}\n", - " self._metadata_map['chr'] = self._all_chroms\n", - " self._metadata_map['celltype'] = self._all_celltypes\n", - " \n", - " # Get the splits\n", - " if split_scheme=='official':\n", - " split_scheme = 'standard'\n", - " \n", - " self._split_scheme = split_scheme\n", - " self._split_dict = {\n", - " 'train': 0,\n", - " 'id_val': 1,\n", - " 'test': 2,\n", - " 'val': 3\n", - " }\n", - " self._split_names = {\n", - " 'train': 'Train',\n", - " 'id_val': 'Validation (ID)',\n", - " 'test': 'Test',\n", - " 'val': 'Validation (OOD)',\n", - " }\n", - " \n", - " # Load sequence and DNase features\n", - " sequence_filename = os.path.join(self._data_dir, 'sequence.npz')\n", - " seq_arr = np.load(sequence_filename)\n", - " self._seq_bp = {}\n", - " for chrom in self._all_chroms: #seq_arr:\n", - " self._seq_bp[chrom] = seq_arr[chrom]\n", - " print(chrom, time.time() - itime)\n", - " \n", - " self._dnase_allcelltypes = {}\n", - " ct = 'avg'\n", - " dnase_avg_bw_path = os.path.join(self._data_dir, 'Leopard_dnase/{}.bigwig'.format(ct))\n", - " self._dnase_allcelltypes[ct] = pyBigWig.open(dnase_avg_bw_path)\n", - " for ct in self._all_celltypes:\n", - " \"\"\"\n", - " dnase_filename = os.path.join(self._data_dir, '{}_dnase.npz'.format(ct))\n", - " dnase_npz_contents = np.load(dnase_filename)\n", - " self._dnase_allcelltypes[ct] = {}\n", - " for chrom in self._all_chroms: #self._seq_bp:\n", - " self._dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom]\n", - " \"\"\"\n", - " dnase_bw_path = os.path.join(self._data_dir, 'Leopard_dnase/{}.bigwig'.format(ct))\n", - " self._dnase_allcelltypes[ct] = pyBigWig.open(dnase_bw_path)\n", - " \n", - " self._metadata_df = pd.read_csv(\n", - " self._data_dir + '/labels/MAX/metadata_df.bed', sep='\\t', header=None, \n", - " index_col=None, names=['chr', 'start', 'stop', 'celltype']\n", - " )\n", - " \n", - " train_regions_mask = np.isin(self._metadata_df['chr'], self._train_chroms)\n", - " val_regions_mask = np.isin(self._metadata_df['chr'], self._val_chroms)\n", - " test_regions_mask = np.isin(self._metadata_df['chr'], self._test_chroms)\n", - " train_celltype_mask = np.isin(self._metadata_df['celltype'], self._train_celltypes)\n", - " val_celltype_mask = np.isin(self._metadata_df['celltype'], self._val_celltype)\n", - " test_celltype_mask = np.isin(self._metadata_df['celltype'], self._test_celltype)\n", - " \n", - " split_array = -1*np.ones(self._metadata_df.shape[0]).astype(int)\n", - " split_array[np.logical_and(train_regions_mask, train_celltype_mask)] = self._split_dict['train']\n", - " split_array[np.logical_and(test_regions_mask, test_celltype_mask)] = self._split_dict['test']\n", - " # Validate using validation chr, either using a designated validation cell line ('val') or a training cell line ('id_val')\n", - " split_array[np.logical_and(val_regions_mask, val_celltype_mask)] = self._split_dict['val']\n", - " split_array[np.logical_and(val_regions_mask, train_celltype_mask)] = self._split_dict['id_val']\n", - " \n", - " if self._split_scheme=='standard':\n", - " self._metadata_df.insert(len(self._metadata_df.columns), 'split', split_array)\n", - " else:\n", - " raise ValueError(f'Split scheme {self._split_scheme} not recognized')\n", - " \n", - " metadata_mask = (self._metadata_df['split'] != -1)\n", - " self._metadata_df = self._metadata_df[self._metadata_df['split'] != -1]\n", - " \n", - " chr_ints = self._metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['chr'])] )).values\n", - " celltype_ints = self._metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['celltype'])] )).values\n", - " self._split_array = self._metadata_df['split'].values\n", - " self._y_array = torch.Tensor(np.load(self._data_dir + '/labels/MAX/metadata_y.npy'))\n", - " self._y_array = self._y_array[metadata_mask]\n", - " \n", - " self._metadata_array = torch.stack(\n", - " (torch.LongTensor(chr_ints), \n", - " torch.LongTensor(celltype_ints)\n", - " ),\n", - " dim=1)\n", - " self._metadata_fields = ['chr', 'celltype']\n", - " \n", - " self._eval_grouper = CombinatorialGrouper(\n", - " dataset=self,\n", - " groupby_fields=['celltype'])\n", - " \n", - " self._metric = Accuracy()\n", - " \n", - " super().__init__(root_dir, download, split_scheme)\n", - " \n", - " \"\"\"\n", - " def get_random_label_vec(metadata_df, output_size=128):\n", - " # Sample a positively labeled region at random\n", - " pos_mdf = metadata_df[metadata_df['y'] == 1] #.iloc[ metadata_df['chr'] == s['chr'], : ]\n", - " pos_seed_region = pos_mdf.iloc[np.random.randint(pos_mdf.shape[0])]\n", - "\n", - " # Extract regions from this chromosome in this celltype, to get a window of labels from\n", - " chr_msk = np.array(metadata_df['chr']) == pos_seed_region['chr']\n", - " ct_msk = np.array(metadata_df['celltype']) == pos_seed_region['celltype']\n", - " mdf = metadata_df[chr_msk & ct_msk]\n", - "\n", - " # Get labels\n", - " start_ndx = np.where(mdf['start'] == pos_seed_region['start'])[0][0]\n", - " y_label_vec = mdf.iloc[start_ndx:start_ndx+output_size, :]['y']\n", - " \"\"\"\n", - " \n", - " def get_input(self, idx, window_size=12800):\n", - " \"\"\"\n", - " Returns x for a given idx in metadata_array, which has been filtered to only take windows with the desired stride.\n", - " Computes this from: \n", - " (1) sequence features in self._seq_bp\n", - " (2) DNase bigwig file handles in self._dnase_allcelltypes\n", - " (3) Metadata for the index (location along the genome with 6400bp window width)\n", - " (4) Window_size, the length of sequence returned (centered on the 6400bp region in (3))\n", - " \"\"\"\n", - " this_metadata = self._metadata_df.iloc[idx, :]\n", - " interval_start = this_metadata['start'] - int(window_size/4)\n", - " interval_end = interval_start + window_size #this_metadata['stop']\n", - " seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end]\n", - " dnase_bw = self._dnase_allcelltypes[this_metadata['celltype']]\n", - " dnase_this = dnase_bw.values(chrom, interval_start, interval_end, numpy=True)\n", - " dnase_avg = self._dnase_allcelltypes['avg'].values(chrom, interval_start, interval_end, numpy=True)\n", - " return torch.tensor(np.column_stack(\n", - " [np.nan_to_num(seq_this), np.nan_to_num(dnase_this), np.nan_to_num(dnase_avg)]\n", - " ))\n", - "\n", - " def eval(self, y_pred, y_true, metadata):\n", - " return self.standard_group_eval(\n", - " self._metric,\n", - " self._eval_grouper,\n", - " y_pred, y_true, metadata)" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "chr3 3.0425407886505127\n", - "chr4 5.967821359634399\n", - "chr5 8.747126340866089\n", - "chr6 11.370141744613647\n", - "chr7 13.802208423614502\n", - "chr10 15.875979900360107\n", - "chr12 17.929850339889526\n", - "chr13 19.67976665496826\n", - "chr14 21.306750059127808\n", - "chr15 22.866544723510742\n", - "chr16 24.241100788116455\n", - "chr17 25.480982303619385\n", - "chr18 26.677065134048462\n", - "chr19 27.579110622406006\n", - "chr20 28.545915603637695\n", - "chr22 29.323810577392578\n", - "chrX 31.698036670684814\n", - "chr2 35.40705943107605\n", - "chr9 37.5518524646759\n", - "chr11 39.61783218383789\n", - "chr1 43.411964893341064\n", - "chr8 45.64823389053345\n", - "chr21 46.377281188964844\n" - ] - } - ], - "source": [ - "full_dataset_encode = EncodeTFBSDataset(\n", - " root_dir=config.root_dir,\n", - " download=config.download,\n", - " split_scheme=config.split_scheme,\n", - " **config.dataset_kwargs)" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -875,12 +614,38 @@ "output_type": "stream", "text": [ "Train data...\n", - " n = 64993\n", - "Validation data...\n", - " n = 4860\n", + " celltype = H1-hESC: n = 5314\n", + " celltype = HCT116: n = 4759\n", + " celltype = HeLa-S3: n = 4635\n", + " celltype = HepG2: n = 4459\n", + " celltype = K562: n = 5169\n", + " celltype = A549: n = 0\n", + " celltype = GM12878: n = 0\n", + "Validation (ID) data...\n", + " celltype = H1-hESC: n = 6872\n", + " celltype = HCT116: n = 6315\n", + " celltype = HeLa-S3: n = 4219\n", + " celltype = HepG2: n = 8356\n", + " celltype = K562: n = 6538\n", + " celltype = A549: n = 0\n", + " celltype = GM12878: n = 0\n", "Test data...\n", - " n = 4742\n", - "Dout: 9\n" + " celltype = H1-hESC: n = 0\n", + " celltype = HCT116: n = 0\n", + " celltype = HeLa-S3: n = 0\n", + " celltype = HepG2: n = 0\n", + " celltype = K562: n = 0\n", + " celltype = A549: n = 0\n", + " celltype = GM12878: n = 4487\n", + "Validation (OOD) data...\n", + " celltype = H1-hESC: n = 0\n", + " celltype = HCT116: n = 0\n", + " celltype = HeLa-S3: n = 0\n", + " celltype = HepG2: n = 0\n", + " celltype = K562: n = 0\n", + " celltype = A549: n = 6728\n", + " celltype = GM12878: n = 0\n", + "Dout: 128\n" ] } ], @@ -978,7 +743,7 @@ { "data": { "text/plain": [ - "tensor(0.8208, device='cuda:0', grad_fn=)" + "tensor(0.7212, device='cuda:0', grad_fn=)" ] }, "execution_count": 9, @@ -1001,7 +766,7 @@ { "data": { "text/plain": [ - "" + "torch.Size([64, 128])" ] }, "execution_count": 10, @@ -1011,7 +776,7 @@ ], "source": [ "#np.unique(full_dataset._metadata_df['split'], return_counts=True)\n", - "full_dataset" + "y_true.squeeze().shape" ] }, { @@ -1036,6 +801,26 @@ "importlib.reload(train)" ] }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Namespace(algo_log_metric='multitask_avgprec', algorithm='ERM', batch_size=64, coral_penalty_weight=None, dataset='encode-tfbs', dataset_kwargs={}, device=device(type='cuda', index=0), distinct_groups=None, download=False, eval_epoch=None, eval_loader='standard', eval_only=False, eval_splits=[], eval_transform=None, evaluate_all_splits=True, frac=1.0, group_dro_step_size=None, groupby_fields=['celltype'], irm_lambda=None, irm_penalty_anneal_iters=None, loader_kwargs={'num_workers': 1, 'pin_memory': True}, log_dir='./logs', log_every=50, loss_function='multitask_bce', lr=0.001, max_grad_norm=None, max_token_length=None, model='leopard', model_kwargs={'pretrained': False}, n_epochs=5, n_groups_per_batch=2, no_group_logging=False, optimizer='Adam', optimizer_kwargs={}, progress_bar=False, resize_scale=None, resume=False, root_dir='data', save_best=True, save_last=True, save_step=None, scheduler=None, scheduler_kwargs={}, scheduler_metric_name=None, scheduler_metric_split='val', seed=0, split_scheme='official', target_resolution=None, train_loader='standard', train_transform=None, uniform_over_groups=False, use_wandb=False, val_metric='acc_avg', val_metric_decreasing=False, weight_decay=0.01)" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "config" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -1055,25 +840,128 @@ "\n", "Epoch [0]:\n", "\n", - "Train:\n" + "Train:\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (2432,) [1 0 1 ... 1 1 0] (2432,) 0.09923777357272781 tensor(0.0992, dtype=torch.float64)\n", + "why [False False False ... False False False] (1792,) [1 1 0 ... 1 0 1] (1792,) 0.18020602071676678 tensor(0.1802, dtype=torch.float64)\n", + "why [False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False True\n", + " True True True True True True True True True True True False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False True\n", + " True True True True True True True True True True True False\n", + " False True True True True True True True True True True True\n", + " True True True True True False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False True True True True True True True True True\n", + " True True True False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False True True\n", + " True True True True True True True True True True False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False True True True True True True True True True\n", + " True True True False False False False False False False False False\n", + " False False False False True True True True True True True True\n", + " True True True False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False True True True True True True True\n", + " True True True True True False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False] (896,) [1 0 1 0 0 1 1 0 1 0 0 1 0 1 1 1 0 1 1 0 1 0 1 1 0 0 1 1 1 1 1 1 0 1 0 0 0\n", + " 1 0 1 1 0 0 0 1 1 1 0 1 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 1 0 0 1 0 0 1 0 0 0\n", + " 0 1 1 1 0 1 1 0 1 0 1 0 0 1 0 0 1 1 0 0 0 0 1 1 0 0 0 0 1 0 0 1 0 0 0 0 0\n", + " 1 1 0 0 1 1 0 1 0 1 0 0 1 0 1 0 1 1 0 0 0 1 1 1 1 1 0 1 0 0 0 1 0 0 1 1 0\n", + " 1 0 0 1 0 0 0 0 1 1 1 1 0 0 0 0 0 1 0 1 0 1 1 1 1 0 1 1 0 0 1 1 1 1 1 1 0\n", + " 0 0 1 1 1 1 1 1 1 1 0 0 1 0 1 0 0 0 1 1 1 0 1 1 1 1 1 1 0 1 0 1 1 0 1 1 0\n", + " 1 0 1 1 1 1 1 0 0 1 1 1 0 0 1 1 0 0 1 0 1 0 0 1 1 1 1 0 1 1 1 1 1 0 1 1 1\n", + " 1 0 0 0 1 0 0 1 0 1 1 1 0 1 1 1 0 0 1 1 0 0 0 1 0 0 1 1 0 1 0 0 0 1 0 0 0\n", + " 1 1 0 1 0 1 1 0 0 1 0 1 0 1 1 1 1 0 1 0 1 0 1 0 0 1 1 0 1 1 0 1 1 1 1 0 0\n", + " 1 1 0 1 0 0 1 1 1 0 0 0 1 0 0 0 0 0 1 0 0 1 1 0 0 0 0 0 1 1 0 1 1 1 1 0 1\n", + " 0 0 0 1 1 0 0 0 0 1 1 1 0 1 1 1 0 1 0 1 0 0 0 0 0 1 1 1 1 1 0 1 0 1 0 0 1\n", + " 1 0 0 1 1 0 0 0 1 1 1 1 0 1 1 0 1 1 1 0 0 1 0 1 0 0 1 1 0 0 0 0 0 0 1 0 0\n", + " 0 0 1 0 1 0 1 0 1 1 1 1 1 0 1 1 1 0 1 0 1 1 0 0 0 0 1 0 1 1 0 1 0 1 1 1 0\n", + " 0 1 0 1 1 1 0 0 1 0 0 1 0 1 1 0 1 0 1 0 0 1 0 1 1 0 1 0 1 0 0 1 1 1 0 1 0\n", + " 1 0 1 1 1 1 1 0 0 1 0 1 0 0 0 0 1 0 0 0 0 0 1 0 0 0 1 0 1 1 0 0 0 0 0 1 0\n", + " 1 1 0 0 0 0 1 1 0 1 0 0 0 0 1 0 1 1 0 1 1 1 0 1 1 0 0 0 0 1 1 1 0 1 0 1 1\n", + " 0 0 0 0 1 1 1 1 1 0 0 1 0 1 0 1 1 0 1 1 1 1 1 1 1 1 0 0 1 1 1 1 1 1 1 1 1\n", + " 1 0 1 1 1 0 1 0 0 0 0 1 0 0 1 1 1 0 1 1 1 1 0 0 0 0 0 1 1 1 0 0 1 0 1 0 0\n", + " 1 0 0 1 0 1 0 1 1 1 0 1 1 0 1 1 0 0 1 0 0 1 1 1 1 0 1 1 0 1 1 1 0 1 1 0 1\n", + " 0 0 1 1 1 1 0 1 0 0 1 1 1 0 1 1 1 1 0 0 1 0 1 0 0 0 1 1 0 1 0 0 1 0 1 0 0\n", + " 1 0 1 1 0 1 1 1 1 0 0 1 0 0 1 1 1 1 0 0 1 1 1 1 0 0 1 0 0 1 0 1 0 1 1 1 0\n", + " 1 0 1 1 0 0 0 0 1 0 0 0 1 1 1 1 0 0 0 1 1 0 1 1 0 0 1 0 0 1 1 1 1 0 0 0 1\n", + " 0 1 0 1 1 0 1 0 0 0 1 1 1 1 1 1 0 1 0 1 1 0 1 0 0 1 1 1 1 0 1 1 0 1 1 1 1\n", + " 1 1 0 0 0 0 0 1 0 1 0 0 0 1 0 1 0 0 1 0 1 1 0 1 0 1 1 1 0 1 1 1 1 0 0 1 1\n", + " 1 1 1 0 1 1 0 1] (896,) 0.12653340353855683 tensor(0.1265, dtype=torch.float64)\n", + "why [False False False ... False False False] (1152,) [0 0 0 ... 1 1 0] (1152,) 0.15009138463477656 tensor(0.1501, dtype=torch.float64)\n", + "why [ True True True ... True True True] (1920,) [0 0 1 ... 1 0 0] (1920,) 0.13893378955027236 tensor(0.1389, dtype=torch.float64)\n" ] }, { - "ename": "KeyboardInterrupt", - "evalue": "", + "ename": "RuntimeError", + "evalue": "All input tensors must be on the same device. Received cpu and cuda:0", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0mbest_val_metric\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m train(\n\u001b[0m\u001b[1;32m 26\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdatasets\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/wilds/examples/train.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(algorithm, datasets, general_logger, config, epoch_offset, best_val_metric)\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0;31m# First run training\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 90\u001b[0;31m \u001b[0mrun_epoch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgeneral_logger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 91\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;31m# Then run val\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/wilds/examples/train.py\u001b[0m in \u001b[0;36mrun_epoch\u001b[0;34m(algorithm, dataset, general_logger, epoch, config, train)\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miterator\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 43\u001b[0;31m \u001b[0mbatch_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 44\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0mbatch_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/wilds/examples/algorithms/single_model_algorithm.py\u001b[0m in \u001b[0;36mupdate\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 103\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_update\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[0;31m# log results\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 105\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_log\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 106\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msanitize_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/examples/algorithms/group_algorithm.py\u001b[0m in \u001b[0;36mupdate_log\u001b[0;34m(self, results)\u001b[0m\n\u001b[1;32m 54\u001b[0m return_dict=False)\n\u001b[1;32m 55\u001b[0m \u001b[0mbatch_log\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34mf'{self.group_prefix}{m.name}'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgroup_metrics\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 56\u001b[0;31m batch_log[m.agg_metric_field] = m.compute(\n\u001b[0m\u001b[1;32m 57\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'y_pred'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'y_true'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/wilds/common/metrics/metric.py\u001b[0m in \u001b[0;36mcompute\u001b[0;34m(self, y_pred, y_true, return_dict)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0magg_metric\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0.\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0my_true\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 87\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 88\u001b[0;31m \u001b[0magg_metric\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 89\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 90\u001b[0m results = {\n", - "\u001b[0;32m~/wilds/wilds/common/metrics/metric.py\u001b[0m in \u001b[0;36m_compute\u001b[0;34m(self, y_pred, y_true)\u001b[0m\n\u001b[1;32m 224\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 225\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_compute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 226\u001b[0;31m \u001b[0mflattened_metrics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_flattened\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 227\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mflattened_metrics\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 228\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0.\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0my_true\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/wilds/common/metrics/metric.py\u001b[0m in \u001b[0;36mcompute_flattened\u001b[0;34m(self, y_pred, y_true, return_dict)\u001b[0m\n\u001b[1;32m 240\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_flattened\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 241\u001b[0m \u001b[0mis_labeled\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m~\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misnan\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_true\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 242\u001b[0;31m \u001b[0mbatch_idx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwhere\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mis_labeled\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 243\u001b[0m \u001b[0mflattened_y_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mis_labeled\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 244\u001b[0m \u001b[0mflattened_y_true\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mis_labeled\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + "\u001b[0;32m~/wilds/examples/algorithms/group_algorithm.py\u001b[0m in \u001b[0;36mupdate_log\u001b[0;34m(self, results)\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mm\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogged_metrics\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_group_logging\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m group_metrics, group_counts, worst_group_metric = m.compute_group_wise(\n\u001b[0m\u001b[1;32m 50\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'y_pred'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'y_true'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/wilds/common/metrics/metric.py\u001b[0m in \u001b[0;36mcompute_group_wise\u001b[0;34m(self, y_pred, y_true, g, n_groups, return_dict)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mDictionary\u001b[0m \u001b[0mof\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \"\"\"\n\u001b[0;32m--> 114\u001b[0;31m \u001b[0mgroup_metrics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgroup_counts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mworst_group_metric\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compute_group_wise\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_groups\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 115\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/wilds/common/metrics/metric.py\u001b[0m in \u001b[0;36m_compute_group_wise\u001b[0;34m(self, y_pred, y_true, g, n_groups)\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mg\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mgroup_idx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 135\u001b[0m y_true[g == group_idx]))\n\u001b[0;32m--> 136\u001b[0;31m \u001b[0mgroup_metrics\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgroup_metrics\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 137\u001b[0m \u001b[0mworst_group_metric\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mworst\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgroup_metrics\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mgroup_counts\u001b[0m\u001b[0;34m>\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: All input tensors must be on the same device. Received cpu and cuda:0" ] } ], diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py index 9a0d4de1..73ef7b04 100644 --- a/wilds/common/metrics/all_metrics.py +++ b/wilds/common/metrics/all_metrics.py @@ -84,6 +84,9 @@ def _compute_flattened(self, flattened_y_pred, flattened_y_true): to_ret = torch.tensor(score).to(flattened_y_pred.device) print("why ", ytr, ytr.shape, ypr, ypr.shape, score, to_ret) return to_ret + + def _compute(self, y_pred, y_true): + return self._compute_flattened(y_pred, y_true) def worst(self, metrics): return minimum(metrics) @@ -130,6 +133,33 @@ def _compute(self, y_pred, y_true): def worst(self, metrics): return minimum(metrics) +class MTAveragePrecision(Metric): + def __init__(self, prediction_fn=logits_to_binary_pred, name=None, average='macro'): + self.prediction_fn = prediction_fn + if name is None: + name = f'avgprec' + if average is not None: + name+=f'-{average}' + self.average = average + super().__init__(name=name) + + def _compute(self, y_pred, y_true): + if self.prediction_fn is not None: + y_pred = self.prediction_fn(y_pred) + ytr = np.array(torch.flatten(y_true.squeeze()).detach().cpu().numpy() > 0) + ypr = torch.flatten(y_pred.squeeze()).detach().cpu().numpy() + score = sklearn.metrics.average_precision_score( + ytr, + ypr, + average=self.average + ) + to_ret = torch.tensor(score)#.to(flattened_y_pred.device) + print("why ", ytr, ytr.shape, ypr, ypr.shape, score, to_ret) + return to_ret + + def worst(self, metrics): + return minimum(metrics) + class F1(Metric): def __init__(self, prediction_fn=logits_to_pred, name=None, average='binary'): self.prediction_fn = prediction_fn From 027e7fc553d9a24e7dd6332c155e3e5f8d77c6e9 Mon Sep 17 00:00:00 2001 From: aikanor Date: Fri, 19 Mar 2021 20:34:39 -0700 Subject: [PATCH 047/244] integration besides eval 3/ --- examples/sbox_run_expt.ipynb | 1117 ++++++++++++++++++++++++++- wilds/common/metrics/all_metrics.py | 3 +- 2 files changed, 1097 insertions(+), 23 deletions(-) diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb index 2aad102b..4eeaee7a 100644 --- a/examples/sbox_run_expt.ipynb +++ b/examples/sbox_run_expt.ipynb @@ -106,7 +106,7 @@ { "data": { "text/plain": [ - "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" + "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" ] }, "execution_count": 2, @@ -325,9 +325,9 @@ "Progress bar: False\n", "Resume: False\n", "\n", - "chr3 2.979121685028076\n", - "chr2 6.626891374588013\n", - "chr1 10.355815410614014\n" + "chr3 3.016324281692505\n", + "chr2 6.676640510559082\n", + "chr1 10.41373872756958\n" ] } ], @@ -766,7 +766,13 @@ { "data": { "text/plain": [ - "torch.Size([64, 128])" + "array([[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", + " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", + " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", + " ...,\n", + " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", + " [0. , 0. , 0. , ..., 0.5, 0.5, 0.5],\n", + " [0. , 0. , 0. , ..., 0.5, 0.5, 1. ]], dtype=float32)" ] }, "execution_count": 10, @@ -776,7 +782,7 @@ ], "source": [ "#np.unique(full_dataset._metadata_df['split'], return_counts=True)\n", - "y_true.squeeze().shape" + "y_true.squeeze().detach().numpy()" ] }, { @@ -809,7 +815,7 @@ { "data": { "text/plain": [ - "Namespace(algo_log_metric='multitask_avgprec', algorithm='ERM', batch_size=64, coral_penalty_weight=None, dataset='encode-tfbs', dataset_kwargs={}, device=device(type='cuda', index=0), distinct_groups=None, download=False, eval_epoch=None, eval_loader='standard', eval_only=False, eval_splits=[], eval_transform=None, evaluate_all_splits=True, frac=1.0, group_dro_step_size=None, groupby_fields=['celltype'], irm_lambda=None, irm_penalty_anneal_iters=None, loader_kwargs={'num_workers': 1, 'pin_memory': True}, log_dir='./logs', log_every=50, loss_function='multitask_bce', lr=0.001, max_grad_norm=None, max_token_length=None, model='leopard', model_kwargs={'pretrained': False}, n_epochs=5, n_groups_per_batch=2, no_group_logging=False, optimizer='Adam', optimizer_kwargs={}, progress_bar=False, resize_scale=None, resume=False, root_dir='data', save_best=True, save_last=True, save_step=None, scheduler=None, scheduler_kwargs={}, scheduler_metric_name=None, scheduler_metric_split='val', seed=0, split_scheme='official', target_resolution=None, train_loader='standard', train_transform=None, uniform_over_groups=False, use_wandb=False, val_metric='acc_avg', val_metric_decreasing=False, weight_decay=0.01)" + "device(type='cpu')" ] }, "execution_count": 11, @@ -818,7 +824,7 @@ } ], "source": [ - "config" + "y_true.device" ] }, { @@ -842,8 +848,8 @@ "\n", "Train:\n", "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (2432,) [1 0 1 ... 1 1 0] (2432,) 0.09923777357272781 tensor(0.0992, dtype=torch.float64)\n", - "why [False False False ... False False False] (1792,) [1 1 0 ... 1 0 1] (1792,) 0.18020602071676678 tensor(0.1802, dtype=torch.float64)\n", + "why [False False False ... False False False] (2432,) [1 0 1 ... 1 1 0] (2432,) 0.09923777357272781 tensor(0.0992, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1792,) [1 1 0 ... 1 0 1] (1792,) 0.18020602071676678 tensor(0.1802, device='cuda:0', dtype=torch.float64)\n", "why [False False False False False False False False False False False False\n", " False False False False False False False False False False False False\n", " False False False False False False False False False False False True\n", @@ -942,26 +948,1095 @@ " 1 0 1 1 0 0 0 0 1 0 0 0 1 1 1 1 0 0 0 1 1 0 1 1 0 0 1 0 0 1 1 1 1 0 0 0 1\n", " 0 1 0 1 1 0 1 0 0 0 1 1 1 1 1 1 0 1 0 1 1 0 1 0 0 1 1 1 1 0 1 1 0 1 1 1 1\n", " 1 1 0 0 0 0 0 1 0 1 0 0 0 1 0 1 0 0 1 0 1 1 0 1 0 1 1 1 0 1 1 1 1 0 0 1 1\n", - " 1 1 1 0 1 1 0 1] (896,) 0.12653340353855683 tensor(0.1265, dtype=torch.float64)\n", - "why [False False False ... False False False] (1152,) [0 0 0 ... 1 1 0] (1152,) 0.15009138463477656 tensor(0.1501, dtype=torch.float64)\n", - "why [ True True True ... True True True] (1920,) [0 0 1 ... 1 0 0] (1920,) 0.13893378955027236 tensor(0.1389, dtype=torch.float64)\n" + " 1 1 1 0 1 1 0 1] (896,) 0.12653340353855683 tensor(0.1265, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1152,) [0 0 0 ... 1 1 0] (1152,) 0.15009138463477656 tensor(0.1501, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... True True True] (1920,) [0 0 1 ... 1 0 0] (1920,) 0.13893378955027236 tensor(0.1389, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [1 0 1 ... 1 1 0] (8192,) 0.13583524260280033 tensor(0.1358, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... True True True] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.07954545454545454 tensor(0.0795, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.11778846153846154 tensor(0.1178, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.08138020833333333 tensor(0.0814, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.196875 tensor(0.1969, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.1623263888888889 tensor(0.1623, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1292724609375 tensor(0.1293, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.09678819444444445 tensor(0.0968, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.19587053571428573 tensor(0.1959, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.1015625 tensor(0.1016, device='cuda:0', dtype=torch.float64)\n", + "why [False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False True True True\n", + " True True True True True True True True True True True False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False True True\n", + " True True True True False False False False False False True True\n", + " True True True True True True True True True True True True\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False True True True True True True True\n", + " True True True True True True False False False False False False\n", + " False False False False True True True True True True True True\n", + " True True True True True True False False False False False False\n", + " False True True True True True True True True True True True\n", + " True True True True True True True False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False] (512,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] (512,) 0.154296875 tensor(0.1543, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.142578125 tensor(0.1426, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1318359375 tensor(0.1318, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (2432,) [0 0 0 ... 0 0 0] (2432,) 0.09580592105263158 tensor(0.0958, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.2506510416666667 tensor(0.2507, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.09809027777777778 tensor(0.0981, device='cuda:0', dtype=torch.float64)\n", + "why [ True False False ... False False False] (1024,) [0 0 0 ... 0 0 0] (1024,) 0.2353515625 tensor(0.2354, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.109375 tensor(0.1094, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.14599609375 tensor(0.1460, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... True True True] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.10107421875 tensor(0.1011, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.20454545454545456 tensor(0.2045, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1024,) [0 0 0 ... 0 0 0] (1024,) 0.099609375 tensor(0.0996, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.19810267857142858 tensor(0.1981, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.10885416666666667 tensor(0.1089, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1417236328125 tensor(0.1417, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.11221590909090909 tensor(0.1122, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.13040865384615385 tensor(0.1304, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... True True True] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.09588068181818182 tensor(0.0959, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.16471354166666666 tensor(0.1647, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.14935661764705882 tensor(0.1494, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1328125 tensor(0.1328, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.10997596153846154 tensor(0.1100, device='cuda:0', dtype=torch.float64)\n", + "why [ True False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.1232638888888889 tensor(0.1233, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.13616071428571427 tensor(0.1362, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.184375 tensor(0.1844, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.14149305555555555 tensor(0.1415, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1380615234375 tensor(0.1381, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.1 tensor(0.1000, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.15980113636363635 tensor(0.1598, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.11160714285714286 tensor(0.1116, device='cuda:0', dtype=torch.float64)\n", + "why [False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False True True True True True True True True True\n", + " True True True True True True True True True True True True\n", + " True True True True False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " True True True True True True True True True True True True\n", + " True False True True True True True True True True True True\n", + " True True True True False False False False False False True True\n", + " True True True True True True True True True True True True\n", + " True True True True False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False True True True\n", + " True True True True True True True True True True True False\n", + " False False False False True True True True True True True True\n", + " True True True True True True True True True True False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False True\n", + " True True True True True True True True True True True True\n", + " True False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " True True True False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False True True True True True True True True True True True\n", + " True True True False False True True True True True True True\n", + " True True True True True True True True True True True True\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False True True True True True\n", + " True True True True True True True True True False False False\n", + " False False False False False False False False False False False False] (768,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] (768,) 0.21614583333333334 tensor(0.2161, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2944,) [0 0 0 ... 0 0 0] (2944,) 0.1328125 tensor(0.1328, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.135498046875 tensor(0.1355, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.13385416666666666 tensor(0.1339, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.15178571428571427 tensor(0.1518, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.1203125 tensor(0.1203, device='cuda:0', dtype=torch.float64)\n", + "why [False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False True True True True True True True True\n", + " True True True True True True False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False True True True True True True True True True\n", + " True True True True True False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False True True True True\n", + " False False True True True True True True True True False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False True True True True True True True True\n", + " True True True True True True False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " True True True True True True True True True True True True\n", + " True True False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " True True True True True True True True True True True True\n", + " True False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False] (896,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0] (896,) 0.09040178571428571 tensor(0.0904, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.1045673076923077 tensor(0.1046, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1239013671875 tensor(0.1239, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.13051470588235295 tensor(0.1305, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.14609375 tensor(0.1461, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.12239583333333333 tensor(0.1224, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.14118303571428573 tensor(0.1412, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.16193181818181818 tensor(0.1619, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.13916015625 tensor(0.1392, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False True True] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.09801136363636363 tensor(0.0980, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (2816,) [0 0 0 ... 0 0 0] (2816,) 0.10404829545454546 tensor(0.1040, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.0875 tensor(0.0875, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.20099431818181818 tensor(0.2010, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.13984375 tensor(0.1398, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False True True] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1226806640625 tensor(0.1227, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.11145833333333334 tensor(0.1115, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.181640625 tensor(0.1816, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.11484375 tensor(0.1148, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.1658653846153846 tensor(0.1659, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.16685267857142858 tensor(0.1669, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1483154296875 tensor(0.1483, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.1302568958818959 tensor(0.1303, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.189453125 tensor(0.1895, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.11067708333333333 tensor(0.1107, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.23468815928270043 tensor(0.2347, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.13385416666666666 tensor(0.1339, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.16484484726123597 tensor(0.1648, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.11263020833333333 tensor(0.1126, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.15576171875 tensor(0.1558, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.11610949612403101 tensor(0.1161, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.17734375 tensor(0.1773, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.171875 tensor(0.1719, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1438577872555272 tensor(0.1439, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... True True True] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.12439903846153846 tensor(0.1244, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.22088068181818182 tensor(0.2209, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.15223817567567566 tensor(0.1522, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.1356534090909091 tensor(0.1357, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.15503202814868278 tensor(0.1550, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.15613628135565832 tensor(0.1561, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.08984375 tensor(0.0898, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.14543269230769232 tensor(0.1454, device='cuda:0', dtype=torch.float64)\n", + "why [False True True ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.1545138888888889 tensor(0.1545, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.16601859327507598 tensor(0.1660, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.09486607142857142 tensor(0.0949, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1307907754109508 tensor(0.1308, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [ True True True True True True True True True True True True\n", + " True True True False True True True True True True True True\n", + " True True True False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False True True True\n", + " True True True True True True True True True False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False True True True True True True\n", + " False False False True True True True True True False False False\n", + " False False False False False False False False False False True True\n", + " True True True True True True True True True True False False\n", + " False False False False False False False False False False False False\n", + " False True True True True True True True True True True True\n", + " True True True True False False False False True True True True\n", + " True True True True True True True False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " True True True True True True True True True True True True\n", + " True True False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False True True True True\n", + " True True True True True True True False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False True True True True True True True\n", + " True True True True False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False True True True True True True True\n", + " True True True True True False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False True True True True True True True True True True True\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " True True True True True True True True True True True False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False True True True True True True\n", + " True True True True True False False False False False False False\n", + " False False False False False False False False] (896,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0] (896,) 0.18861607142857142 tensor(0.1886, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2560,) [0 0 0 ... 0 0 0] (2560,) 0.2031711368110236 tensor(0.2032, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.12560096153846154 tensor(0.1256, device='cuda:0', dtype=torch.float64)\n", + "why [False False False False False False True True True True True True\n", + " True True True True True True True True False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False True True True True True True True True\n", + " True True True True True False False False False False False False\n", + " False False False False False True True True True True True True\n", + " True True True True True True True False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False True True True True True\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False True True True True True True True True True\n", + " True True True True True True True True False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False True True True True True True True True True\n", + " True True True True False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False True True True True True True True True\n", + " True True True True True True False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False] (768,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] (768,) 0.1171875 tensor(0.1172, device='cuda:0', dtype=torch.float64)\n", + "why [False False True ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.14322916666666666 tensor(0.1432, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.16135470753972053 tensor(0.1614, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.13365334378265414 tensor(0.1337, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.14312537741545892 tensor(0.1431, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.1383054595896147 tensor(0.1383, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.1884765625 tensor(0.1885, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.0889423076923077 tensor(0.0889, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1392934035570018 tensor(0.1393, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False True] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.1148314123790117 tensor(0.1148, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.234375 tensor(0.2344, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.13518363161819538 tensor(0.1352, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.1484375 tensor(0.1484, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.20069918995700245 tensor(0.2007, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.16518310916225415 tensor(0.1652, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.14157774390243902 tensor(0.1416, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.16829982517482517 tensor(0.1683, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1536,) [0 0 0 ... 0 1 1] (1536,) 0.12203414351851852 tensor(0.1220, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.14015534682080924 tensor(0.1402, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.18266864778921865 tensor(0.1827, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.15107465864301803 tensor(0.1511, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.12805550230061352 tensor(0.1281, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.1798145077383275 tensor(0.1798, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1536,) [0 0 0 ... 1 0 0] (1536,) 0.14846865031897927 tensor(0.1485, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.2182291666666667 tensor(0.2182, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.17879293893129775 tensor(0.1788, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.17408185325186412 tensor(0.1741, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.13385180995475113 tensor(0.1339, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.1634497549019608 tensor(0.1634, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.2111472315436242 tensor(0.2111, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.17961774553571427 tensor(0.1796, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.19505408546397282 tensor(0.1951, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.17720760641838973 tensor(0.1772, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.17260322523480418 tensor(0.1726, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.19631456413210446 tensor(0.1963, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.16002286585365852 tensor(0.1600, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.15676843030872636 tensor(0.1568, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.19980746809032893 tensor(0.1998, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.17694871945488722 tensor(0.1769, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... True True False] (3200,) [0 0 0 ... 0 0 0] (3200,) 0.17646062940470833 tensor(0.1765, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.21987862976406533 tensor(0.2199, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1536,) [1 0 0 ... 0 0 0] (1536,) 0.22485079470618036 tensor(0.2249, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.18892249103942654 tensor(0.1889, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1024,) [0 0 0 ... 0 0 0] (1024,) 0.20539447623239437 tensor(0.2054, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... True True False] (8192,) [1 0 0 ... 0 0 0] (8192,) 0.1956759851363835 tensor(0.1957, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [ True True True ... False False False] (2560,) [1 1 1 ... 0 0 0] (2560,) 0.16270833333333334 tensor(0.1627, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1536,) [0 0 0 ... 1 0 0] (1536,) 0.28461934747103557 tensor(0.2846, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.2885416666666667 tensor(0.2885, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.24493883087633087 tensor(0.2449, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.1776813162682728 tensor(0.1777, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (8192,) [1 1 1 ... 0 0 0] (8192,) 0.22326946266948078 tensor(0.2233, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.19251085890430153 tensor(0.1925, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.22502709178398156 tensor(0.2250, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.22283878504672897 tensor(0.2228, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1024,) [0 0 0 ... 0 0 0] (1024,) 0.2044723429144385 tensor(0.2045, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.1890666335978836 tensor(0.1891, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1970471833881579 tensor(0.1970, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.17814201811043567 tensor(0.1781, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.22176106178589622 tensor(0.2218, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.15586979984301413 tensor(0.1559, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.19933712121212122 tensor(0.1993, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.27441314553990614 tensor(0.2744, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.20748284786370724 tensor(0.2075, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.1627858889528193 tensor(0.1628, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.3289409447955064 tensor(0.3289, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1408,) [0 1 0 ... 0 0 0] (1408,) 0.25750782574670666 tensor(0.2575, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.2380265050832091 tensor(0.2380, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.20203645462301223 tensor(0.2020, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.2304055108248235 tensor(0.2304, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [ True True True ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.18352952167414052 tensor(0.1835, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.391332129896404 tensor(0.3913, device='cuda:0', dtype=torch.float64)\n", + "why [False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False True\n", + " True True True True True True True True True True False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False True True True True True True True True\n", + " True True True True True True True True True True True True\n", + " True True True False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " True True True True True True True True True True True False\n", + " False True True True True True True True True True True True\n", + " True False True True True True True True True True True True\n", + " True False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False True True True True True True\n", + " True True True True True True False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False True True True True\n", + " True True True True True True True True False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False True True True True True True True True\n", + " True True False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " True True True True True True True True True True True True\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " True True True True True True True True True False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False] (896,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 1 1 1 1 1 1 1 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 1 1 1 0 1 1 0 0 0 0 0 1 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0] (896,) 0.43876971003366205 tensor(0.4388, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.26807482215447154 tensor(0.2681, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.2458394306739895 tensor(0.2458, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.2842815311314583 tensor(0.2843, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.24575731426692965 tensor(0.2458, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.276717519724741 tensor(0.2767, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.38168526600954644 tensor(0.3817, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True True True True True True True True True True\n", + " True True True True True True False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False True True True True True True\n", + " True True True True True True True True True True True False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False True True True True True\n", + " True True True True True True True True False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False True True True True True True True True True True\n", + " True True True True False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False True True True True True True True True\n", + " True True True True True True True True True True False False\n", + " False False False False False True True True True True True True\n", + " True True True True True True True True True True True False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False True True True\n", + " True True True True True True True True True True False False\n", + " False True True True True True True True True True True True\n", + " True True False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False True True\n", + " True True True True True True True True True True True True\n", + " True True True True True False False False False False False False\n", + " False False False False True True True True True True True True\n", + " True True True True False False True True True True True True\n", + " True True True True True True True True True True True True\n", + " True False False False False True True True True True True True\n", + " True True True True True True False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " True True True True True True True True True True True True\n", + " True True True True True True True True False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False True True True True True True True True\n", + " True True True True True True True True True True False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False] (896,) [0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 1 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 1\n", + " 1 1 1 1 1 1 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 1 0 1 1 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0] (896,) 0.3275530937683716 tensor(0.3276, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.24250047241118666 tensor(0.2425, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.27537596564595973 tensor(0.2754, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.339521139314602 tensor(0.3395, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.28316756119010217 tensor(0.2832, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1024,) [0 1 0 ... 0 0 0] (1024,) 0.30224860634648365 tensor(0.3022, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.23057474330872174 tensor(0.2306, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.22791799898259513 tensor(0.2279, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.27437629915291323 tensor(0.2744, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.21319969405140976 tensor(0.2132, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.3474399687036469 tensor(0.3474, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.19751082251082253 tensor(0.1975, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.3353790123844628 tensor(0.3354, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.24501893939393937 tensor(0.2450, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.2624466475767001 tensor(0.2624, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1408,) [1 1 1 ... 0 0 0] (1408,) 0.22450973341004987 tensor(0.2245, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.27231664754255114 tensor(0.2723, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1024,) [0 0 0 ... 0 0 0] (1024,) 0.3152901785714286 tensor(0.3153, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.35922695360195356 tensor(0.3592, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... True True True] (2432,) [0 0 0 ... 0 0 0] (2432,) 0.26736473289421736 tensor(0.2674, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [1 1 1 ... 0 0 0] (8192,) 0.28538833123099405 tensor(0.2854, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.245172509039775 tensor(0.2452, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.24340502699055327 tensor(0.2434, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.28707033026885964 tensor(0.2871, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1024,) [0 0 0 ... 0 0 0] (1024,) 0.2957705135233918 tensor(0.2958, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.2895262781476896 tensor(0.2895, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.2656280862586716 tensor(0.2656, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... True True False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.19391790985177615 tensor(0.1939, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1152,) [0 0 0 ... 1 1 1] (1152,) 0.39839248075956224 tensor(0.3984, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.3400271739130435 tensor(0.3400, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.26218694096601075 tensor(0.2622, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.25949223766281415 tensor(0.2595, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.2731843170244799 tensor(0.2732, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (2432,) [0 0 0 ... 0 0 0] (2432,) 0.23153263758670284 tensor(0.2315, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.3294548915822105 tensor(0.3295, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1024,) [0 0 0 ... 0 0 0] (1024,) 0.50768331438611 tensor(0.5077, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.17941607556456285 tensor(0.1794, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.4005733735380117 tensor(0.4006, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.32525391000796444 tensor(0.3253, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.28327316031926486 tensor(0.2833, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2688,) [0 0 0 ... 0 0 0] (2688,) 0.2455340291329215 tensor(0.2455, device='cuda:0', dtype=torch.float64)\n", + "why [False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False True True True True True True True\n", + " True True True True False False False False False False False False\n", + " False False False False False False False False False False True True\n", + " True True True True True True True True True True True True\n", + " True True False False False False False False False False False False\n", + " False False True True True True True True True True True True\n", + " True True False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False True True True True False False False False\n", + " False False False False False False False False False False False False\n", + " True True True True True True True True True True True True\n", + " True True True True True True True True True True True True\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False True True True True True True True True True True True\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False True True True\n", + " True True True True True True True True False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False True True True True True True True True\n", + " True True True False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False True True True True True True True True True True\n", + " True True False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False] (896,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1\n", + " 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0 0\n", + " 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 1 1 1 1 1\n", + " 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0] (896,) 0.36439732142857145 tensor(0.3644, device='cuda:0', dtype=torch.float64)\n", + "why [False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False True True True True True False False False True\n", + " True True True True True True True True True True True True\n", + " True True True True True True True True True True True False\n", + " False False False False True True True True True True True True\n", + " True True True False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False True True True True True True True True\n", + " True True True True True True False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " True True True True True True True True True True True True\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False True True True True True True True True True True\n", + " True True True True False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False True True True True True True True\n", + " True True True True True True True False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False True True True True\n", + " True True True True True True True True True True True True\n", + " True True True True True True False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False True True True True True True\n", + " True True True True True True True True False False False False\n", + " False False False False False False False False False False False False] (768,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1\n", + " 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] (768,) 0.36334134615384617 tensor(0.3633, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.36328125 tensor(0.3633, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.3073375105806347 tensor(0.3073, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.42102430988608963 tensor(0.4210, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.39475473771436803 tensor(0.3948, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.3678160635096611 tensor(0.3678, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.19351388184584178 tensor(0.1935, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.24591191813804175 tensor(0.2459, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.3096451568959731 tensor(0.3096, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.24075629195519133 tensor(0.2408, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.17249526515151514 tensor(0.1725, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.2863095238095238 tensor(0.2863, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.4012790080941676 tensor(0.4013, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.32623064828253506 tensor(0.3262, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.29373969403168476 tensor(0.2937, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.3421500286608995 tensor(0.3422, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1664,) [0 0 0 ... 1 1 0] (1664,) 0.22848216513818703 tensor(0.2285, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.21294610507246378 tensor(0.2129, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.4324312010246706 tensor(0.4324, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.42839099459862173 tensor(0.4284, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (8192,) [0 0 0 ... 1 1 0] (8192,) 0.3411826173375903 tensor(0.3412, device='cuda:0', dtype=torch.float64)\n" ] }, { - "ename": "RuntimeError", - "evalue": "All input tensors must be on the same device. Received cpu and cuda:0", + "ename": "KeyboardInterrupt", + "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0mbest_val_metric\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m train(\n\u001b[0m\u001b[1;32m 26\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdatasets\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/wilds/examples/train.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(algorithm, datasets, general_logger, config, epoch_offset, best_val_metric)\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0;31m# First run training\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 90\u001b[0;31m \u001b[0mrun_epoch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgeneral_logger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 91\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;31m# Then run val\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/wilds/examples/train.py\u001b[0m in \u001b[0;36mrun_epoch\u001b[0;34m(algorithm, dataset, general_logger, epoch, config, train)\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miterator\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 43\u001b[0;31m \u001b[0mbatch_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 44\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0mbatch_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/examples/algorithms/single_model_algorithm.py\u001b[0m in \u001b[0;36mupdate\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 103\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_update\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[0;31m# log results\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 105\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_log\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 106\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msanitize_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/examples/algorithms/group_algorithm.py\u001b[0m in \u001b[0;36mupdate_log\u001b[0;34m(self, results)\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mm\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogged_metrics\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_group_logging\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m group_metrics, group_counts, worst_group_metric = m.compute_group_wise(\n\u001b[0m\u001b[1;32m 50\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'y_pred'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'y_true'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/wilds/common/metrics/metric.py\u001b[0m in \u001b[0;36mcompute_group_wise\u001b[0;34m(self, y_pred, y_true, g, n_groups, return_dict)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mDictionary\u001b[0m \u001b[0mof\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \"\"\"\n\u001b[0;32m--> 114\u001b[0;31m \u001b[0mgroup_metrics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgroup_counts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mworst_group_metric\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compute_group_wise\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_groups\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 115\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/wilds/common/metrics/metric.py\u001b[0m in \u001b[0;36m_compute_group_wise\u001b[0;34m(self, y_pred, y_true, g, n_groups)\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mg\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mgroup_idx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 135\u001b[0m y_true[g == group_idx]))\n\u001b[0;32m--> 136\u001b[0;31m \u001b[0mgroup_metrics\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgroup_metrics\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 137\u001b[0m \u001b[0mworst_group_metric\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mworst\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgroup_metrics\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mgroup_counts\u001b[0m\u001b[0;34m>\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mRuntimeError\u001b[0m: All input tensors must be on the same device. Received cpu and cuda:0" + "\u001b[0;32m~/wilds/examples/algorithms/single_model_algorithm.py\u001b[0m in \u001b[0;36mupdate\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0;31m# process batch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 102\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 103\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_update\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 104\u001b[0m \u001b[0;31m# log results\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_log\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/examples/algorithms/single_model_algorithm.py\u001b[0m in \u001b[0;36m_update\u001b[0;34m(self, results)\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_grad_norm\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[0mclip_grad_norm_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_grad_norm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 122\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 123\u001b[0m self.step_schedulers(\n\u001b[1;32m 124\u001b[0m \u001b[0mis_epoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/torch/autograd/grad_mode.py\u001b[0m in \u001b[0;36mdecorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__class__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mcast\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/torch/optim/adam.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 106\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[0mbeta1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbeta2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgroup\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'betas'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 108\u001b[0;31m F.adam(params_with_grad,\n\u001b[0m\u001b[1;32m 109\u001b[0m \u001b[0mgrads\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 110\u001b[0m \u001b[0mexp_avgs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/torch/optim/functional.py\u001b[0m in \u001b[0;36madam\u001b[0;34m(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, beta1, beta2, lr, weight_decay, eps)\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0mdenom\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mmax_exp_avg_sq\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mmath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbias_correction2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 93\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 94\u001b[0;31m \u001b[0mdenom\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mexp_avg_sq\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mmath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbias_correction2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 95\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 96\u001b[0m \u001b[0mstep_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlr\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mbias_correction1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py index 73ef7b04..3d45a3d8 100644 --- a/wilds/common/metrics/all_metrics.py +++ b/wilds/common/metrics/all_metrics.py @@ -153,8 +153,7 @@ def _compute(self, y_pred, y_true): ypr, average=self.average ) - to_ret = torch.tensor(score)#.to(flattened_y_pred.device) - print("why ", ytr, ytr.shape, ypr, ypr.shape, score, to_ret) + to_ret = torch.tensor(score).to(y_pred.device) return to_ret def worst(self, metrics): From ba27784c2cc84f78b402ccec44d8d99a59c79bd5 Mon Sep 17 00:00:00 2001 From: aikanor Date: Sat, 20 Mar 2021 13:04:33 -0700 Subject: [PATCH 048/244] rebase --- examples/sbox_run_expt.ipynb | 20 ++++++++------------ wilds/common/metrics/all_metrics.py | 1 - wilds/common/metrics/metric.py | 2 +- 3 files changed, 9 insertions(+), 14 deletions(-) diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb index 4eeaee7a..c4a25cc4 100644 --- a/examples/sbox_run_expt.ipynb +++ b/examples/sbox_run_expt.ipynb @@ -62,18 +62,9 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 14, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:root:The WILDS package is out of date. Your version is 1.0.0, while the latest version is 1.1.0.\n", - "WARNING:root:The OGB package is out of date. Your version is 1.2.4, while the latest version is 1.3.0.\n" - ] - } - ], + "outputs": [], "source": [ "import os, csv, sys\n", "os.environ['CUDA_VISIBLE_DEVICES'] = '4'\n", @@ -837,7 +828,12 @@ { "cell_type": "code", "execution_count": 12, - "metadata": {}, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, "outputs": [ { "name": "stdout", diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py index 3d45a3d8..e4d9681f 100644 --- a/wilds/common/metrics/all_metrics.py +++ b/wilds/common/metrics/all_metrics.py @@ -82,7 +82,6 @@ def _compute_flattened(self, flattened_y_pred, flattened_y_true): average=self.average ) to_ret = torch.tensor(score).to(flattened_y_pred.device) - print("why ", ytr, ytr.shape, ypr, ypr.shape, score, to_ret) return to_ret def _compute(self, y_pred, y_true): diff --git a/wilds/common/metrics/metric.py b/wilds/common/metrics/metric.py index 207dacf6..e0086e2a 100644 --- a/wilds/common/metrics/metric.py +++ b/wilds/common/metrics/metric.py @@ -232,7 +232,7 @@ def _compute(self, y_pred, y_true): def _compute_group_wise(self, y_pred, y_true, g, n_groups): flattened_metrics, indices = self.compute_flattened(y_pred, y_true, return_dict=False) flattened_g = g[indices] - print(flattened_metrics.shape, flattened_g.shape, (indices > 0).sum(), y_pred.shape, y_true.shape) + # print(flattened_metrics.shape, flattened_g.shape, (indices > 0).sum(), y_pred.shape, y_true.shape) group_metrics, group_counts = avg_over_groups(flattened_metrics, flattened_g, n_groups) worst_group_metric = self.worst(group_metrics[group_counts>0]) return group_metrics, group_counts, worst_group_metric From 1418d9de45dbe5c7a39acbb54deb53b9ff8acf40 Mon Sep 17 00:00:00 2001 From: aikanor Date: Sat, 20 Mar 2021 17:56:37 -0700 Subject: [PATCH 049/244] remove staging nb --- examples/sbox_run_expt.ipynb | 2158 ---------------------------------- 1 file changed, 2158 deletions(-) delete mode 100644 examples/sbox_run_expt.ipynb diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb deleted file mode 100644 index c4a25cc4..00000000 --- a/examples/sbox_run_expt.ipynb +++ /dev/null @@ -1,2158 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# run_expt.py contents\n", - "\n", - "## 1) Preamble" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "47.42578125\n" - ] - } - ], - "source": [ - "import os, psutil; print(psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2)" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true - } - }, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'bw' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# import pyBigWig\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;31m# %timeit bw = pyBigWig.open(\"/users/abalsubr/wilds/examples/data/encode-tfbs_v1.0/DNASE.K562.fc.signal.bigwig\")\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mget_ipython\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_line_magic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'timeit'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"bw.values('chr1', 10000, 22800, numpy=True)\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mrun_line_magic\u001b[0;34m(self, magic_name, line, _stack_depth)\u001b[0m\n\u001b[1;32m 2334\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'local_ns'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_local_scope\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstack_depth\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2335\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuiltin_trap\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2336\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2337\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2338\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m\u001b[0m in \u001b[0;36mtimeit\u001b[0;34m(self, line, cell, local_ns)\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/IPython/core/magic.py\u001b[0m in \u001b[0;36m\u001b[0;34m(f, *a, **k)\u001b[0m\n\u001b[1;32m 185\u001b[0m \u001b[0;31m# but it's overkill for just that one bit of state.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 186\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mmagic_deco\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 187\u001b[0;31m \u001b[0mcall\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 188\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcallable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/IPython/core/magics/execution.py\u001b[0m in \u001b[0;36mtimeit\u001b[0;34m(self, line, cell, local_ns)\u001b[0m\n\u001b[1;32m 1167\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mindex\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1168\u001b[0m \u001b[0mnumber\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m10\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0mindex\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1169\u001b[0;31m \u001b[0mtime_number\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtimer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtimeit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnumber\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1170\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtime_number\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0;36m0.2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1171\u001b[0m \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/IPython/core/magics/execution.py\u001b[0m in \u001b[0;36mtimeit\u001b[0;34m(self, number)\u001b[0m\n\u001b[1;32m 167\u001b[0m \u001b[0mgc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdisable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 168\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 169\u001b[0;31m \u001b[0mtiming\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minner\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mit\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtimer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 170\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 171\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mgcold\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m\u001b[0m in \u001b[0;36minner\u001b[0;34m(_it, _timer)\u001b[0m\n", - "\u001b[0;31mNameError\u001b[0m: name 'bw' is not defined" - ] - } - ], - "source": [ - "# import pyBigWig\n", - "# %timeit bw = pyBigWig.open(\"/users/abalsubr/wilds/examples/data/encode-tfbs_v1.0/DNASE.K562.fc.signal.bigwig\")\n", - "%timeit bw.values('chr1', 10000, 22800, numpy=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], - "source": [ - "import os, csv, sys\n", - "os.environ['CUDA_VISIBLE_DEVICES'] = '4'\n", - "\n", - "import time\n", - "import argparse\n", - "import numpy as np, pandas as pd\n", - "import torch\n", - "import torch.nn as nn\n", - "import torchvision\n", - "import pyBigWig\n", - "from collections import defaultdict\n", - "\n", - "from wilds.common.data_loaders import get_train_loader, get_eval_loader\n", - "from wilds.common.grouper import CombinatorialGrouper\n", - "\n", - "from utils import set_seed, Logger, BatchLogger, log_config, ParseKwargs, load, initialize_wandb, log_group_data, parse_bool\n", - "from train import train, evaluate\n", - "from algorithms.initializer import initialize_algorithm\n", - "from transforms import initialize_transform\n", - "from configs.utils import populate_defaults\n", - "import configs.supported as supported" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "''' set default hyperparams in default_hyperparams.py '''\n", - "parser = argparse.ArgumentParser()\n", - "\n", - "# Required arguments\n", - "parser.add_argument('-d', '--dataset', choices=supported.datasets, required=True)\n", - "parser.add_argument('--algorithm', required=True, choices=supported.algorithms)\n", - "parser.add_argument('--root_dir', required=True,\n", - " help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')\n", - "\n", - "# Dataset\n", - "parser.add_argument('--split_scheme', help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')\n", - "parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--download', default=False, type=parse_bool, const=True, nargs='?',\n", - " help='If true, tries to downloads the dataset if it does not exist in root_dir.')\n", - "parser.add_argument('--frac', type=float, default=1.0,\n", - " help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.')\n", - "\n", - "# Loaders\n", - "parser.add_argument('--loader_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--train_loader', choices=['standard', 'group'])\n", - "parser.add_argument('--uniform_over_groups', type=parse_bool, const=True, nargs='?')\n", - "parser.add_argument('--distinct_groups', type=parse_bool, const=True, nargs='?')\n", - "parser.add_argument('--n_groups_per_batch', type=int)\n", - "parser.add_argument('--batch_size', type=int)\n", - "parser.add_argument('--eval_loader', choices=['standard'], default='standard')\n", - "\n", - "# Model\n", - "parser.add_argument('--model', choices=supported.models)\n", - "parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},\n", - " help='keyword arguments for model initialization passed as key1=value1 key2=value2')\n", - "\n", - "# Transforms\n", - "parser.add_argument('--train_transform', choices=supported.transforms)\n", - "parser.add_argument('--eval_transform', choices=supported.transforms)\n", - "parser.add_argument('--target_resolution', nargs='+', type=int, help='target resolution. for example --target_resolution 224 224 for standard resnet.')\n", - "parser.add_argument('--resize_scale', type=float)\n", - "parser.add_argument('--max_token_length', type=int)\n", - "\n", - "# Objective\n", - "parser.add_argument('--loss_function', choices = supported.losses)\n", - "\n", - "# Algorithm\n", - "parser.add_argument('--groupby_fields', nargs='+')\n", - "parser.add_argument('--group_dro_step_size', type=float)\n", - "parser.add_argument('--coral_penalty_weight', type=float)\n", - "parser.add_argument('--irm_lambda', type=float)\n", - "parser.add_argument('--irm_penalty_anneal_iters', type=int)\n", - "parser.add_argument('--algo_log_metric')\n", - "\n", - "# Model selection\n", - "parser.add_argument('--val_metric')\n", - "parser.add_argument('--val_metric_decreasing', type=parse_bool, const=True, nargs='?')\n", - "\n", - "# Optimization\n", - "parser.add_argument('--n_epochs', type=int)\n", - "parser.add_argument('--optimizer', choices=supported.optimizers)\n", - "parser.add_argument('--lr', type=float)\n", - "parser.add_argument('--weight_decay', type=float)\n", - "parser.add_argument('--max_grad_norm', type=float)\n", - "parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "\n", - "# Scheduler\n", - "parser.add_argument('--scheduler', choices=supported.schedulers)\n", - "parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')\n", - "parser.add_argument('--scheduler_metric_name')\n", - "\n", - "# Evaluation\n", - "parser.add_argument('--evaluate_all_splits', type=parse_bool, const=True, nargs='?', default=True)\n", - "parser.add_argument('--eval_splits', nargs='+', default=[])\n", - "parser.add_argument('--eval_only', type=parse_bool, const=True, nargs='?', default=False)\n", - "parser.add_argument('--eval_epoch', default=None, type=int)\n", - "\n", - "# Misc\n", - "parser.add_argument('--device', type=int, default=0)\n", - "parser.add_argument('--seed', type=int, default=0)\n", - "parser.add_argument('--log_dir', default='./logs')\n", - "parser.add_argument('--log_every', default=50, type=int)\n", - "parser.add_argument('--save_step', type=int)\n", - "parser.add_argument('--save_best', type=parse_bool, const=True, nargs='?', default=True)\n", - "parser.add_argument('--save_last', type=parse_bool, const=True, nargs='?', default=True)\n", - "parser.add_argument('--no_group_logging', type=parse_bool, const=True, nargs='?')\n", - "parser.add_argument('--use_wandb', type=parse_bool, const=True, nargs='?', default=False)\n", - "parser.add_argument('--progress_bar', type=parse_bool, const=True, nargs='?', default=False)\n", - "parser.add_argument('--resume', type=parse_bool, const=True, nargs='?', default=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "argstr_camelyon = \"--dataset camelyon17 --algorithm ERM --root_dir data\"\n", - "config_camelyon = parser.parse_args(argstr_camelyon.split())\n", - "config_camelyon = populate_defaults(config_camelyon)\n", - "\n", - "argstr_bdd100k = \"--dataset bdd100k --algorithm ERM --root_dir data\"\n", - "config_bdd100k = parser.parse_args(argstr_bdd100k.split())\n", - "config_bdd100k = populate_defaults(config_bdd100k)\n", - "\n", - "argstr_encode = \"--dataset encode-tfbs --algorithm ERM --root_dir data\"\n", - "config_encode = parser.parse_args(argstr_encode.split())\n", - "config_encode = populate_defaults(config_encode)\n", - "\n", - "config = config_camelyon\n", - "config = config_encode\n", - "# config = config_bdd100k\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Namespace(algo_log_metric=None, algorithm='ERM', batch_size=None, coral_penalty_weight=None, dataset='encode-tfbs', dataset_kwargs={}, device=0, distinct_groups=None, download=False, eval_epoch=None, eval_loader='standard', eval_only=False, eval_splits=[], eval_transform=None, evaluate_all_splits=True, frac=1.0, group_dro_step_size=None, groupby_fields=None, irm_lambda=None, irm_penalty_anneal_iters=None, loader_kwargs={'num_workers': 1, 'pin_memory': True}, log_dir='./logs', log_every=50, loss_function=None, lr=None, max_grad_norm=None, max_token_length=None, model=None, model_kwargs={'pretrained': False}, n_epochs=None, n_groups_per_batch=None, no_group_logging=None, optimizer=None, optimizer_kwargs={'momentum': 0.9}, progress_bar=False, resize_scale=None, resume=False, root_dir='data', save_best=True, save_last=True, save_step=None, scheduler=None, scheduler_kwargs={}, scheduler_metric_name=None, scheduler_metric_split='val', seed=0, split_scheme=None, target_resolution=None, train_loader=None, train_transform=None, uniform_over_groups=None, use_wandb=False, val_metric=None, val_metric_decreasing=None, weight_decay=None)" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "argstr_camelyon = \"--dataset camelyon17 --algorithm ERM --root_dir data\"\n", - "# argstr_camelyon = \"--dataset civilcomments --algorithm ERM --root_dir data\"\n", - "config_camelyon = parser.parse_args(argstr_camelyon.split())\n", - "\n", - "argstr_encode = \"--dataset encode-tfbs --algorithm ERM --root_dir data\"\n", - "config_encode = parser.parse_args(argstr_encode.split())\n", - "config_encode" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "config.optimizer_kwargs = {}" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Dataset: encode-tfbs\n", - "Algorithm: ERM\n", - "Root dir: data\n", - "Split scheme: official\n", - "Dataset kwargs: {}\n", - "Download: False\n", - "Frac: 1.0\n", - "Loader kwargs: {'num_workers': 1, 'pin_memory': True}\n", - "Train loader: standard\n", - "Uniform over groups: False\n", - "Distinct groups: None\n", - "N groups per batch: 2\n", - "Batch size: 64\n", - "Eval loader: standard\n", - "Model: leopard\n", - "Model kwargs: {'pretrained': False}\n", - "Train transform: None\n", - "Eval transform: None\n", - "Target resolution: None\n", - "Resize scale: None\n", - "Max token length: None\n", - "Loss function: multitask_bce\n", - "Groupby fields: ['celltype']\n", - "Group dro step size: None\n", - "Coral penalty weight: None\n", - "Irm lambda: None\n", - "Irm penalty anneal iters: None\n", - "Algo log metric: multitask_avgprec\n", - "Val metric: acc_avg\n", - "Val metric decreasing: False\n", - "N epochs: 5\n", - "Optimizer: Adam\n", - "Lr: 0.001\n", - "Weight decay: 0.01\n", - "Max grad norm: None\n", - "Optimizer kwargs: {}\n", - "Scheduler: None\n", - "Scheduler kwargs: {}\n", - "Scheduler metric split: val\n", - "Scheduler metric name: None\n", - "Evaluate all splits: True\n", - "Eval splits: []\n", - "Eval only: False\n", - "Eval epoch: None\n", - "Device: cuda:0\n", - "Seed: 0\n", - "Log dir: ./logs\n", - "Log every: 50\n", - "Save step: None\n", - "Save best: True\n", - "Save last: True\n", - "No group logging: False\n", - "Use wandb: False\n", - "Progress bar: False\n", - "Resume: False\n", - "\n", - "chr3 3.016324281692505\n", - "chr2 6.676640510559082\n", - "chr1 10.41373872756958\n" - ] - } - ], - "source": [ - "# set device\n", - "config.device = torch.device(\"cuda:\" + str(config.device)) if torch.cuda.is_available() else torch.device(\"cpu\")\n", - "\n", - "## Initialize logs\n", - "if os.path.exists(config.log_dir) and config.resume:\n", - " resume=True\n", - " mode='a'\n", - "elif os.path.exists(config.log_dir) and config.eval_only:\n", - " resume=False\n", - " mode='a'\n", - "else:\n", - " resume=False\n", - " mode='w'\n", - "\n", - "if not os.path.exists(config.log_dir):\n", - " os.makedirs(config.log_dir)\n", - "logger = Logger(os.path.join(config.log_dir, 'log.txt'), mode)\n", - "\n", - "# Record config\n", - "log_config(config, logger)\n", - "\n", - "# Set random seed\n", - "set_seed(config.seed)\n", - "\n", - "# Data\n", - "full_dataset = supported.datasets[config.dataset](\n", - " root_dir=config.root_dir,\n", - " download=config.download,\n", - " split_scheme=config.split_scheme,\n", - " **config.dataset_kwargs)\n", - "\n", - "# To implement data augmentation (i.e., have different transforms\n", - "# at training time vs. test time), modify these two lines:\n", - "train_transform = initialize_transform(\n", - " transform_name=config.train_transform,\n", - " config=config,\n", - " dataset=full_dataset)\n", - "eval_transform = initialize_transform(\n", - " transform_name=config.eval_transform,\n", - " config=config,\n", - " dataset=full_dataset)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2) Initialize dataset object (trial version)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true, - "source_hidden": true - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "chr3 3.0039219856262207\n", - "chr4 5.89985990524292\n", - "chr5 8.640583038330078\n", - "chr6 11.237342596054077\n", - "chr7 13.666043519973755\n", - "chr10 15.858035326004028\n", - "chr12 17.94972252845764\n", - "chr13 19.689449071884155\n", - "chr14 21.30842876434326\n", - "chr15 22.856398582458496\n" - ] - }, - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0m_seq_bp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mchrom\u001b[0m \u001b[0;32min\u001b[0m \u001b[0m_all_chroms\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 58\u001b[0;31m \u001b[0m_seq_bp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mchrom\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mseq_arr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mchrom\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 59\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mchrom\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mitime\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/numpy/lib/npyio.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 252\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmagic\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mformat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mMAGIC_PREFIX\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 253\u001b[0m \u001b[0mbytes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzip\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 254\u001b[0;31m return format.read_array(bytes,\n\u001b[0m\u001b[1;32m 255\u001b[0m \u001b[0mallow_pickle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mallow_pickle\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 256\u001b[0m pickle_kwargs=self.pickle_kwargs)\n", - "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/numpy/lib/format.py\u001b[0m in \u001b[0;36mread_array\u001b[0;34m(fp, allow_pickle, pickle_kwargs)\u001b[0m\n\u001b[1;32m 773\u001b[0m \u001b[0mread_count\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmax_read_count\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcount\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 774\u001b[0m \u001b[0mread_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mread_count\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitemsize\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 775\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_read_bytes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mread_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"array data\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 776\u001b[0m array[i:i+read_count] = numpy.frombuffer(data, dtype=dtype,\n\u001b[1;32m 777\u001b[0m count=read_count)\n", - "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/numpy/lib/format.py\u001b[0m in \u001b[0;36m_read_bytes\u001b[0;34m(fp, size, error_template)\u001b[0m\n\u001b[1;32m 902\u001b[0m \u001b[0;31m# done about that. note that regular files can't be non-blocking\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 903\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 904\u001b[0;31m \u001b[0mr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msize\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 905\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mr\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 906\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0msize\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/zipfile.py\u001b[0m in \u001b[0;36mread\u001b[0;34m(self, n)\u001b[0m\n\u001b[1;32m 938\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_offset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 939\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_eof\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 940\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_read1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 941\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 942\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_readbuffer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/zipfile.py\u001b[0m in \u001b[0;36m_read1\u001b[0;34m(self, n)\u001b[0m\n\u001b[1;32m 1028\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_left\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1029\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_eof\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1030\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_update_crc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1031\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1032\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/zipfile.py\u001b[0m in \u001b[0;36m_update_crc\u001b[0;34m(self, newdata)\u001b[0m\n\u001b[1;32m 953\u001b[0m \u001b[0;31m# No need to compute the CRC if we don't have a reference value\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 954\u001b[0m \u001b[0;32mreturn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 955\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_running_crc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcrc32\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnewdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_running_crc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 956\u001b[0m \u001b[0;31m# Check the CRC if we're at the end of the file\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 957\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_eof\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_running_crc\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_expected_crc\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " - ] - } - ], - "source": [ - "import os, time\n", - "import torch\n", - "import pandas as pd\n", - "import numpy as np\n", - "from wilds.datasets.wilds_dataset import WILDSDataset\n", - "from wilds.common.grouper import CombinatorialGrouper\n", - "from wilds.common.metrics.all_metrics import Accuracy\n", - "\n", - "root_dir='data'\n", - "download=False\n", - "split_scheme='official'\n", - "\n", - "itime = time.time()\n", - "_dataset_name = 'encode-tfbs'\n", - "_version = '1.0'\n", - "_download_url = 'https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/'\n", - "_data_dir = 'data/encode-tfbs_v1.0/'\n", - "_y_size = 1\n", - "_n_classes = 2\n", - "\n", - "_train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX']\n", - "_val_chroms = ['chr2', 'chr9', 'chr11']\n", - "_test_chroms = ['chr1', 'chr8', 'chr21']\n", - "_transcription_factor = 'MAX'\n", - "_train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", - "_val_celltype = ['A549']\n", - "_test_celltype = ['GM12878']\n", - "_all_chroms = _train_chroms + _val_chroms + _test_chroms\n", - "_all_celltypes = _train_celltypes + _val_celltype + _test_celltype\n", - "\n", - "_metadata_map = {}\n", - "_metadata_map['chr'] = _all_chroms\n", - "_metadata_map['celltype'] = _all_celltypes\n", - "\n", - "# Get the splits\n", - "if split_scheme=='official':\n", - " split_scheme = 'standard'\n", - "\n", - "_split_scheme = split_scheme\n", - "_split_dict = {\n", - " 'train': 0,\n", - " 'id_val': 1,\n", - " 'test': 2,\n", - " 'val': 3\n", - "}\n", - "_split_names = {\n", - " 'train': 'Train',\n", - " 'id_val': 'Validation (ID)',\n", - " 'test': 'Test',\n", - " 'val': 'Validation (OOD)',\n", - "}\n", - "\n", - "# Load sequence and DNase features\n", - "sequence_filename = os.path.join(_data_dir, 'sequence.npz')\n", - "seq_arr = np.load(sequence_filename)\n", - "_seq_bp = {}\n", - "for chrom in _all_chroms:\n", - " _seq_bp[chrom] = seq_arr[chrom]\n", - " print(chrom, time.time() - itime)\n", - "\n", - "_dnase_allcelltypes = {}\n", - "ct = 'avg'\n", - "dnase_avg_bw_path = os.path.join(_data_dir, 'Leopard_dnase/{}.bigwig'.format(ct))\n", - "_dnase_allcelltypes[ct] = pyBigWig.open(dnase_avg_bw_path)\n", - "for ct in _all_celltypes:\n", - " \"\"\"\n", - " dnase_filename = os.path.join(self._data_dir, '{}_dnase.npz'.format(ct))\n", - " dnase_npz_contents = np.load(dnase_filename)\n", - " self._dnase_allcelltypes[ct] = {}\n", - " for chrom in self._all_chroms: #self._seq_bp:\n", - " self._dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom]\n", - " \"\"\"\n", - " dnase_bw_path = os.path.join(_data_dir, 'Leopard_dnase/{}.bigwig'.format(ct))\n", - " _dnase_allcelltypes[ct] = pyBigWig.open(dnase_bw_path)\n", - " print(ct, time.time() - itime)\n", - "\n", - "_metadata_df = pd.read_csv(\n", - " _data_dir + 'labels/MAX/metadata_df.bed', sep='\\t', header=None, \n", - " index_col=None, names=['chr', 'start', 'stop', 'celltype']\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "jupyter": { - "source_hidden": true - } - }, - "outputs": [], - "source": [ - "train_regions_mask = np.isin(_metadata_df['chr'], _train_chroms)\n", - "val_regions_mask = np.isin(_metadata_df['chr'], _val_chroms)\n", - "test_regions_mask = np.isin(_metadata_df['chr'], _test_chroms)\n", - "train_celltype_mask = np.isin(_metadata_df['celltype'], _train_celltypes)\n", - "val_celltype_mask = np.isin(_metadata_df['celltype'], _val_celltype)\n", - "test_celltype_mask = np.isin(_metadata_df['celltype'], _test_celltype)\n", - "\n", - "split_array = -1*np.ones(_metadata_df.shape[0]).astype(int)\n", - "split_array[np.logical_and(train_regions_mask, train_celltype_mask)] = _split_dict['train']\n", - "split_array[np.logical_and(test_regions_mask, test_celltype_mask)] = _split_dict['test']\n", - "# Validate using validation chr, either using a designated validation cell line ('val') or a training cell line ('id_val')\n", - "split_array[np.logical_and(val_regions_mask, val_celltype_mask)] = _split_dict['val']\n", - "split_array[np.logical_and(val_regions_mask, train_celltype_mask)] = _split_dict['id_val']\n", - "\n", - "if _split_scheme=='standard':\n", - " _metadata_df.insert(len(_metadata_df.columns), 'split', split_array)\n", - "else:\n", - " raise ValueError(f'Split scheme {_split_scheme} not recognized')\n", - "\n", - "metadata_mask = (_metadata_df['split'] != -1)\n", - "_metadata_df = _metadata_df[_metadata_df['split'] != -1]\n", - "\n", - "chr_ints = _metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(_metadata_map['chr'])] )).values\n", - "celltype_ints = _metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(_metadata_map['celltype'])] )).values\n", - "_split_array = _metadata_df['split'].values\n", - "\n", - "_y_array = torch.Tensor(np.load(_data_dir + 'labels/MAX/metadata_y.npy'))\n", - "_y_array = _y_array[metadata_mask]\n", - "\n", - "_metadata_array = torch.stack(\n", - " (torch.LongTensor(chr_ints), \n", - " torch.LongTensor(celltype_ints)\n", - " ),\n", - " dim=1)\n", - "_metadata_fields = ['chr', 'celltype']" - ] - }, - { - "cell_type": "code", - "execution_count": 325, - "metadata": { - "jupyter": { - "source_hidden": true - } - }, - "outputs": [], - "source": [ - "def get_random_label_vec(\n", - " metadata_df, seed_chr, seed_celltype, seed_start, output_size=128\n", - "):\n", - " \"\"\"\n", - " Given a coordinate in a celltype, gets the labels of \n", - " the `output_size` 200bp bins from that coordinate onward. \n", - " \"\"\"\n", - " itime = time.time()\n", - " \n", - " # Extract regions from this chromosome in this celltype, to get a window of labels from\n", - " # print(time.time() - itime)\n", - " # chr_msk = np.array(metadata_df['chr']) == seed_region['chr']\n", - " # print(time.time() - itime)\n", - " # ct_msk = np.array(metadata_df['celltype']) == seed_region['celltype']\n", - " # mdf = metadata_df[chr_msk & ct_msk]\n", - " seq_size = output_size*50\n", - " mdf = metadata_df.loc[\n", - " (metadata_df['chr'] == seed_chr) & \n", - " (metadata_df['celltype'] == seed_celltype) & \n", - " (metadata_df['start'] >= seed_start) & \n", - " (metadata_df['stop'] < seed_start+seq_size)\n", - " ]\n", - " print(time.time() - itime)\n", - "\n", - " # Get labels\n", - " y_label_vec = np.zeros(output_size)\n", - " y_label_vec[(mdf['start'] - seed_start) // 50] = mdf['y']\n", - " return mdf, y_label_vec" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Initialize algorithm" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Train data...\n", - " celltype = H1-hESC: n = 5314\n", - " celltype = HCT116: n = 4759\n", - " celltype = HeLa-S3: n = 4635\n", - " celltype = HepG2: n = 4459\n", - " celltype = K562: n = 5169\n", - " celltype = A549: n = 0\n", - " celltype = GM12878: n = 0\n", - "Validation (ID) data...\n", - " celltype = H1-hESC: n = 6872\n", - " celltype = HCT116: n = 6315\n", - " celltype = HeLa-S3: n = 4219\n", - " celltype = HepG2: n = 8356\n", - " celltype = K562: n = 6538\n", - " celltype = A549: n = 0\n", - " celltype = GM12878: n = 0\n", - "Test data...\n", - " celltype = H1-hESC: n = 0\n", - " celltype = HCT116: n = 0\n", - " celltype = HeLa-S3: n = 0\n", - " celltype = HepG2: n = 0\n", - " celltype = K562: n = 0\n", - " celltype = A549: n = 0\n", - " celltype = GM12878: n = 4487\n", - "Validation (OOD) data...\n", - " celltype = H1-hESC: n = 0\n", - " celltype = HCT116: n = 0\n", - " celltype = HeLa-S3: n = 0\n", - " celltype = HepG2: n = 0\n", - " celltype = K562: n = 0\n", - " celltype = A549: n = 6728\n", - " celltype = GM12878: n = 0\n", - "Dout: 128\n" - ] - } - ], - "source": [ - "# config = config_encode\n", - "\n", - "train_grouper = CombinatorialGrouper(\n", - " dataset=full_dataset,\n", - " groupby_fields=config.groupby_fields)\n", - "\n", - "datasets = defaultdict(dict)\n", - "for split in full_dataset.split_dict.keys():\n", - " if split=='train':\n", - " transform = train_transform\n", - " verbose = True\n", - " elif split == 'val':\n", - " transform = eval_transform\n", - " verbose = True\n", - " else:\n", - " transform = eval_transform\n", - " verbose = False\n", - " # Get subset\n", - " datasets[split]['dataset'] = full_dataset.get_subset(\n", - " split,\n", - " frac=config.frac,\n", - " transform=transform)\n", - "\n", - " if split == 'train':\n", - " datasets[split]['loader'] = get_train_loader(\n", - " loader=config.train_loader,\n", - " dataset=datasets[split]['dataset'],\n", - " batch_size=config.batch_size,\n", - " uniform_over_groups=config.uniform_over_groups,\n", - " grouper=train_grouper,\n", - " distinct_groups=config.distinct_groups,\n", - " n_groups_per_batch=config.n_groups_per_batch,\n", - " **config.loader_kwargs)\n", - " else:\n", - " datasets[split]['loader'] = get_eval_loader(\n", - " loader=config.eval_loader,\n", - " dataset=datasets[split]['dataset'],\n", - " grouper=train_grouper,\n", - " batch_size=config.batch_size,\n", - " **config.loader_kwargs)\n", - "\n", - " # Set fields\n", - " datasets[split]['split'] = split\n", - " datasets[split]['name'] = full_dataset.split_names[split]\n", - " datasets[split]['verbose'] = verbose\n", - " # Loggers\n", - " # Loggers\n", - " datasets[split]['eval_logger'] = BatchLogger(\n", - " os.path.join(config.log_dir, f'{split}_eval.csv'), mode=mode, use_wandb=(config.use_wandb and verbose))\n", - " datasets[split]['algo_logger'] = BatchLogger(\n", - " os.path.join(config.log_dir, f'{split}_algo.csv'), mode=mode, use_wandb=(config.use_wandb and verbose))\n", - "\n", - " if config.use_wandb:\n", - " initialize_wandb(config)\n", - "\n", - "# Logging dataset info\n", - "if config.no_group_logging and full_dataset.is_classification and full_dataset.y_size==1:\n", - " log_grouper = CombinatorialGrouper(\n", - " dataset=full_dataset,\n", - " groupby_fields=['y'])\n", - "elif config.no_group_logging:\n", - " log_grouper = None\n", - "else:\n", - " log_grouper = train_grouper\n", - "log_group_data(datasets, log_grouper, logger)\n", - "\n", - "## Initialize algorithm\n", - "algorithm = initialize_algorithm(\n", - " config=config,\n", - " datasets=datasets,\n", - " train_grouper=train_grouper)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "for batch in datasets['train']['loader']:\n", - " x, y_true, metadata = batch\n", - " break\n", - "# x = torch.transpose(x, 1, 2)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor(0.7212, device='cuda:0', grad_fn=)" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "d = algorithm.process_batch(batch)\n", - "\n", - "a = algorithm.loss.compute(d['y_pred'], d['y_true'], return_dict=False)\n", - "a" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0.5, 0.5, 0.5],\n", - " [0. , 0. , 0. , ..., 0.5, 0.5, 1. ]], dtype=float32)" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#np.unique(full_dataset._metadata_df['split'], return_counts=True)\n", - "y_true.squeeze().detach().numpy()" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'importlib' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m#import importlib\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mimportlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mNameError\u001b[0m: name 'importlib' is not defined" - ] - } - ], - "source": [ - "#import importlib\n", - "importlib.reload(train)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "device(type='cpu')" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "y_true.device" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Train" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Epoch [0]:\n", - "\n", - "Train:\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (2432,) [1 0 1 ... 1 1 0] (2432,) 0.09923777357272781 tensor(0.0992, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1792,) [1 1 0 ... 1 0 1] (1792,) 0.18020602071676678 tensor(0.1802, device='cuda:0', dtype=torch.float64)\n", - "why [False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False True\n", - " True True True True True True True True True True True False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False True\n", - " True True True True True True True True True True True False\n", - " False True True True True True True True True True True True\n", - " True True True True True False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False True True True True True True True True True\n", - " True True True False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False True True\n", - " True True True True True True True True True True False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False True True True True True True True True True\n", - " True True True False False False False False False False False False\n", - " False False False False True True True True True True True True\n", - " True True True False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False True True True True True True True\n", - " True True True True True False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False] (896,) [1 0 1 0 0 1 1 0 1 0 0 1 0 1 1 1 0 1 1 0 1 0 1 1 0 0 1 1 1 1 1 1 0 1 0 0 0\n", - " 1 0 1 1 0 0 0 1 1 1 0 1 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 1 0 0 1 0 0 1 0 0 0\n", - " 0 1 1 1 0 1 1 0 1 0 1 0 0 1 0 0 1 1 0 0 0 0 1 1 0 0 0 0 1 0 0 1 0 0 0 0 0\n", - " 1 1 0 0 1 1 0 1 0 1 0 0 1 0 1 0 1 1 0 0 0 1 1 1 1 1 0 1 0 0 0 1 0 0 1 1 0\n", - " 1 0 0 1 0 0 0 0 1 1 1 1 0 0 0 0 0 1 0 1 0 1 1 1 1 0 1 1 0 0 1 1 1 1 1 1 0\n", - " 0 0 1 1 1 1 1 1 1 1 0 0 1 0 1 0 0 0 1 1 1 0 1 1 1 1 1 1 0 1 0 1 1 0 1 1 0\n", - " 1 0 1 1 1 1 1 0 0 1 1 1 0 0 1 1 0 0 1 0 1 0 0 1 1 1 1 0 1 1 1 1 1 0 1 1 1\n", - " 1 0 0 0 1 0 0 1 0 1 1 1 0 1 1 1 0 0 1 1 0 0 0 1 0 0 1 1 0 1 0 0 0 1 0 0 0\n", - " 1 1 0 1 0 1 1 0 0 1 0 1 0 1 1 1 1 0 1 0 1 0 1 0 0 1 1 0 1 1 0 1 1 1 1 0 0\n", - " 1 1 0 1 0 0 1 1 1 0 0 0 1 0 0 0 0 0 1 0 0 1 1 0 0 0 0 0 1 1 0 1 1 1 1 0 1\n", - " 0 0 0 1 1 0 0 0 0 1 1 1 0 1 1 1 0 1 0 1 0 0 0 0 0 1 1 1 1 1 0 1 0 1 0 0 1\n", - " 1 0 0 1 1 0 0 0 1 1 1 1 0 1 1 0 1 1 1 0 0 1 0 1 0 0 1 1 0 0 0 0 0 0 1 0 0\n", - " 0 0 1 0 1 0 1 0 1 1 1 1 1 0 1 1 1 0 1 0 1 1 0 0 0 0 1 0 1 1 0 1 0 1 1 1 0\n", - " 0 1 0 1 1 1 0 0 1 0 0 1 0 1 1 0 1 0 1 0 0 1 0 1 1 0 1 0 1 0 0 1 1 1 0 1 0\n", - " 1 0 1 1 1 1 1 0 0 1 0 1 0 0 0 0 1 0 0 0 0 0 1 0 0 0 1 0 1 1 0 0 0 0 0 1 0\n", - " 1 1 0 0 0 0 1 1 0 1 0 0 0 0 1 0 1 1 0 1 1 1 0 1 1 0 0 0 0 1 1 1 0 1 0 1 1\n", - " 0 0 0 0 1 1 1 1 1 0 0 1 0 1 0 1 1 0 1 1 1 1 1 1 1 1 0 0 1 1 1 1 1 1 1 1 1\n", - " 1 0 1 1 1 0 1 0 0 0 0 1 0 0 1 1 1 0 1 1 1 1 0 0 0 0 0 1 1 1 0 0 1 0 1 0 0\n", - " 1 0 0 1 0 1 0 1 1 1 0 1 1 0 1 1 0 0 1 0 0 1 1 1 1 0 1 1 0 1 1 1 0 1 1 0 1\n", - " 0 0 1 1 1 1 0 1 0 0 1 1 1 0 1 1 1 1 0 0 1 0 1 0 0 0 1 1 0 1 0 0 1 0 1 0 0\n", - " 1 0 1 1 0 1 1 1 1 0 0 1 0 0 1 1 1 1 0 0 1 1 1 1 0 0 1 0 0 1 0 1 0 1 1 1 0\n", - " 1 0 1 1 0 0 0 0 1 0 0 0 1 1 1 1 0 0 0 1 1 0 1 1 0 0 1 0 0 1 1 1 1 0 0 0 1\n", - " 0 1 0 1 1 0 1 0 0 0 1 1 1 1 1 1 0 1 0 1 1 0 1 0 0 1 1 1 1 0 1 1 0 1 1 1 1\n", - " 1 1 0 0 0 0 0 1 0 1 0 0 0 1 0 1 0 0 1 0 1 1 0 1 0 1 1 1 0 1 1 1 1 0 0 1 1\n", - " 1 1 1 0 1 1 0 1] (896,) 0.12653340353855683 tensor(0.1265, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1152,) [0 0 0 ... 1 1 0] (1152,) 0.15009138463477656 tensor(0.1501, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... True True True] (1920,) [0 0 1 ... 1 0 0] (1920,) 0.13893378955027236 tensor(0.1389, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [1 0 1 ... 1 1 0] (8192,) 0.13583524260280033 tensor(0.1358, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... True True True] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.07954545454545454 tensor(0.0795, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.11778846153846154 tensor(0.1178, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.08138020833333333 tensor(0.0814, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.196875 tensor(0.1969, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.1623263888888889 tensor(0.1623, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1292724609375 tensor(0.1293, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.09678819444444445 tensor(0.0968, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.19587053571428573 tensor(0.1959, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.1015625 tensor(0.1016, device='cuda:0', dtype=torch.float64)\n", - "why [False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False True True True\n", - " True True True True True True True True True True True False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False True True\n", - " True True True True False False False False False False True True\n", - " True True True True True True True True True True True True\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False True True True True True True True\n", - " True True True True True True False False False False False False\n", - " False False False False True True True True True True True True\n", - " True True True True True True False False False False False False\n", - " False True True True True True True True True True True True\n", - " True True True True True True True False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False] (512,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] (512,) 0.154296875 tensor(0.1543, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.142578125 tensor(0.1426, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1318359375 tensor(0.1318, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (2432,) [0 0 0 ... 0 0 0] (2432,) 0.09580592105263158 tensor(0.0958, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.2506510416666667 tensor(0.2507, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.09809027777777778 tensor(0.0981, device='cuda:0', dtype=torch.float64)\n", - "why [ True False False ... False False False] (1024,) [0 0 0 ... 0 0 0] (1024,) 0.2353515625 tensor(0.2354, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.109375 tensor(0.1094, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.14599609375 tensor(0.1460, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... True True True] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.10107421875 tensor(0.1011, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.20454545454545456 tensor(0.2045, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1024,) [0 0 0 ... 0 0 0] (1024,) 0.099609375 tensor(0.0996, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.19810267857142858 tensor(0.1981, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.10885416666666667 tensor(0.1089, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1417236328125 tensor(0.1417, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.11221590909090909 tensor(0.1122, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.13040865384615385 tensor(0.1304, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... True True True] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.09588068181818182 tensor(0.0959, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.16471354166666666 tensor(0.1647, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.14935661764705882 tensor(0.1494, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1328125 tensor(0.1328, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.10997596153846154 tensor(0.1100, device='cuda:0', dtype=torch.float64)\n", - "why [ True False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.1232638888888889 tensor(0.1233, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.13616071428571427 tensor(0.1362, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.184375 tensor(0.1844, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.14149305555555555 tensor(0.1415, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1380615234375 tensor(0.1381, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.1 tensor(0.1000, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.15980113636363635 tensor(0.1598, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.11160714285714286 tensor(0.1116, device='cuda:0', dtype=torch.float64)\n", - "why [False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False True True True True True True True True True\n", - " True True True True True True True True True True True True\n", - " True True True True False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " True True True True True True True True True True True True\n", - " True False True True True True True True True True True True\n", - " True True True True False False False False False False True True\n", - " True True True True True True True True True True True True\n", - " True True True True False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False True True True\n", - " True True True True True True True True True True True False\n", - " False False False False True True True True True True True True\n", - " True True True True True True True True True True False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False True\n", - " True True True True True True True True True True True True\n", - " True False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " True True True False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False True True True True True True True True True True True\n", - " True True True False False True True True True True True True\n", - " True True True True True True True True True True True True\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False True True True True True\n", - " True True True True True True True True True False False False\n", - " False False False False False False False False False False False False] (768,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] (768,) 0.21614583333333334 tensor(0.2161, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2944,) [0 0 0 ... 0 0 0] (2944,) 0.1328125 tensor(0.1328, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.135498046875 tensor(0.1355, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.13385416666666666 tensor(0.1339, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.15178571428571427 tensor(0.1518, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.1203125 tensor(0.1203, device='cuda:0', dtype=torch.float64)\n", - "why [False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False True True True True True True True True\n", - " True True True True True True False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False True True True True True True True True True\n", - " True True True True True False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False True True True True\n", - " False False True True True True True True True True False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False True True True True True True True True\n", - " True True True True True True False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " True True True True True True True True True True True True\n", - " True True False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " True True True True True True True True True True True True\n", - " True False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False] (896,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0] (896,) 0.09040178571428571 tensor(0.0904, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.1045673076923077 tensor(0.1046, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1239013671875 tensor(0.1239, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.13051470588235295 tensor(0.1305, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.14609375 tensor(0.1461, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.12239583333333333 tensor(0.1224, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.14118303571428573 tensor(0.1412, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.16193181818181818 tensor(0.1619, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.13916015625 tensor(0.1392, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False True True] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.09801136363636363 tensor(0.0980, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (2816,) [0 0 0 ... 0 0 0] (2816,) 0.10404829545454546 tensor(0.1040, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.0875 tensor(0.0875, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.20099431818181818 tensor(0.2010, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.13984375 tensor(0.1398, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False True True] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1226806640625 tensor(0.1227, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.11145833333333334 tensor(0.1115, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.181640625 tensor(0.1816, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.11484375 tensor(0.1148, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.1658653846153846 tensor(0.1659, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.16685267857142858 tensor(0.1669, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1483154296875 tensor(0.1483, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.1302568958818959 tensor(0.1303, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.189453125 tensor(0.1895, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.11067708333333333 tensor(0.1107, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.23468815928270043 tensor(0.2347, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.13385416666666666 tensor(0.1339, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.16484484726123597 tensor(0.1648, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.11263020833333333 tensor(0.1126, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.15576171875 tensor(0.1558, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.11610949612403101 tensor(0.1161, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.17734375 tensor(0.1773, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.171875 tensor(0.1719, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1438577872555272 tensor(0.1439, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... True True True] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.12439903846153846 tensor(0.1244, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.22088068181818182 tensor(0.2209, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.15223817567567566 tensor(0.1522, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.1356534090909091 tensor(0.1357, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.15503202814868278 tensor(0.1550, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.15613628135565832 tensor(0.1561, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.08984375 tensor(0.0898, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.14543269230769232 tensor(0.1454, device='cuda:0', dtype=torch.float64)\n", - "why [False True True ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.1545138888888889 tensor(0.1545, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.16601859327507598 tensor(0.1660, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.09486607142857142 tensor(0.0949, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1307907754109508 tensor(0.1308, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [ True True True True True True True True True True True True\n", - " True True True False True True True True True True True True\n", - " True True True False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False True True True\n", - " True True True True True True True True True False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False True True True True True True\n", - " False False False True True True True True True False False False\n", - " False False False False False False False False False False True True\n", - " True True True True True True True True True True False False\n", - " False False False False False False False False False False False False\n", - " False True True True True True True True True True True True\n", - " True True True True False False False False True True True True\n", - " True True True True True True True False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " True True True True True True True True True True True True\n", - " True True False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False True True True True\n", - " True True True True True True True False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False True True True True True True True\n", - " True True True True False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False True True True True True True True\n", - " True True True True True False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False True True True True True True True True True True True\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " True True True True True True True True True True True False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False True True True True True True\n", - " True True True True True False False False False False False False\n", - " False False False False False False False False] (896,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0] (896,) 0.18861607142857142 tensor(0.1886, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2560,) [0 0 0 ... 0 0 0] (2560,) 0.2031711368110236 tensor(0.2032, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.12560096153846154 tensor(0.1256, device='cuda:0', dtype=torch.float64)\n", - "why [False False False False False False True True True True True True\n", - " True True True True True True True True False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False True True True True True True True True\n", - " True True True True True False False False False False False False\n", - " False False False False False True True True True True True True\n", - " True True True True True True True False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False True True True True True\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False True True True True True True True True True\n", - " True True True True True True True True False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False True True True True True True True True True\n", - " True True True True False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False True True True True True True True True\n", - " True True True True True True False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False] (768,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] (768,) 0.1171875 tensor(0.1172, device='cuda:0', dtype=torch.float64)\n", - "why [False False True ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.14322916666666666 tensor(0.1432, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.16135470753972053 tensor(0.1614, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.13365334378265414 tensor(0.1337, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.14312537741545892 tensor(0.1431, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.1383054595896147 tensor(0.1383, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.1884765625 tensor(0.1885, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.0889423076923077 tensor(0.0889, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1392934035570018 tensor(0.1393, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False True] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.1148314123790117 tensor(0.1148, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.234375 tensor(0.2344, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.13518363161819538 tensor(0.1352, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.1484375 tensor(0.1484, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.20069918995700245 tensor(0.2007, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.16518310916225415 tensor(0.1652, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.14157774390243902 tensor(0.1416, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.16829982517482517 tensor(0.1683, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1536,) [0 0 0 ... 0 1 1] (1536,) 0.12203414351851852 tensor(0.1220, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.14015534682080924 tensor(0.1402, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.18266864778921865 tensor(0.1827, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.15107465864301803 tensor(0.1511, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.12805550230061352 tensor(0.1281, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.1798145077383275 tensor(0.1798, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1536,) [0 0 0 ... 1 0 0] (1536,) 0.14846865031897927 tensor(0.1485, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.2182291666666667 tensor(0.2182, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.17879293893129775 tensor(0.1788, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.17408185325186412 tensor(0.1741, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.13385180995475113 tensor(0.1339, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.1634497549019608 tensor(0.1634, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.2111472315436242 tensor(0.2111, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.17961774553571427 tensor(0.1796, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.19505408546397282 tensor(0.1951, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.17720760641838973 tensor(0.1772, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.17260322523480418 tensor(0.1726, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.19631456413210446 tensor(0.1963, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.16002286585365852 tensor(0.1600, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.15676843030872636 tensor(0.1568, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.19980746809032893 tensor(0.1998, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.17694871945488722 tensor(0.1769, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... True True False] (3200,) [0 0 0 ... 0 0 0] (3200,) 0.17646062940470833 tensor(0.1765, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.21987862976406533 tensor(0.2199, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1536,) [1 0 0 ... 0 0 0] (1536,) 0.22485079470618036 tensor(0.2249, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.18892249103942654 tensor(0.1889, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1024,) [0 0 0 ... 0 0 0] (1024,) 0.20539447623239437 tensor(0.2054, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... True True False] (8192,) [1 0 0 ... 0 0 0] (8192,) 0.1956759851363835 tensor(0.1957, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [ True True True ... False False False] (2560,) [1 1 1 ... 0 0 0] (2560,) 0.16270833333333334 tensor(0.1627, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1536,) [0 0 0 ... 1 0 0] (1536,) 0.28461934747103557 tensor(0.2846, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.2885416666666667 tensor(0.2885, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.24493883087633087 tensor(0.2449, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.1776813162682728 tensor(0.1777, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (8192,) [1 1 1 ... 0 0 0] (8192,) 0.22326946266948078 tensor(0.2233, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.19251085890430153 tensor(0.1925, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.22502709178398156 tensor(0.2250, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.22283878504672897 tensor(0.2228, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1024,) [0 0 0 ... 0 0 0] (1024,) 0.2044723429144385 tensor(0.2045, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.1890666335978836 tensor(0.1891, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1970471833881579 tensor(0.1970, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.17814201811043567 tensor(0.1781, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.22176106178589622 tensor(0.2218, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.15586979984301413 tensor(0.1559, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.19933712121212122 tensor(0.1993, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.27441314553990614 tensor(0.2744, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.20748284786370724 tensor(0.2075, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.1627858889528193 tensor(0.1628, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.3289409447955064 tensor(0.3289, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1408,) [0 1 0 ... 0 0 0] (1408,) 0.25750782574670666 tensor(0.2575, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.2380265050832091 tensor(0.2380, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.20203645462301223 tensor(0.2020, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.2304055108248235 tensor(0.2304, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [ True True True ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.18352952167414052 tensor(0.1835, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.391332129896404 tensor(0.3913, device='cuda:0', dtype=torch.float64)\n", - "why [False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False True\n", - " True True True True True True True True True True False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False True True True True True True True True\n", - " True True True True True True True True True True True True\n", - " True True True False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " True True True True True True True True True True True False\n", - " False True True True True True True True True True True True\n", - " True False True True True True True True True True True True\n", - " True False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False True True True True True True\n", - " True True True True True True False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False True True True True\n", - " True True True True True True True True False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False True True True True True True True True\n", - " True True False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " True True True True True True True True True True True True\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " True True True True True True True True True False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False] (896,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 1 1 1 1 1 1 1 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 1 1 1 0 1 1 0 0 0 0 0 1 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1\n", - " 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1\n", - " 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0] (896,) 0.43876971003366205 tensor(0.4388, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.26807482215447154 tensor(0.2681, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.2458394306739895 tensor(0.2458, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.2842815311314583 tensor(0.2843, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.24575731426692965 tensor(0.2458, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.276717519724741 tensor(0.2767, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.38168526600954644 tensor(0.3817, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True True True True True True True True True True\n", - " True True True True True True False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False True True True True True True\n", - " True True True True True True True True True True True False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False True True True True True\n", - " True True True True True True True True False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False True True True True True True True True True True\n", - " True True True True False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False True True True True True True True True\n", - " True True True True True True True True True True False False\n", - " False False False False False True True True True True True True\n", - " True True True True True True True True True True True False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False True True True\n", - " True True True True True True True True True True False False\n", - " False True True True True True True True True True True True\n", - " True True False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False True True\n", - " True True True True True True True True True True True True\n", - " True True True True True False False False False False False False\n", - " False False False False True True True True True True True True\n", - " True True True True False False True True True True True True\n", - " True True True True True True True True True True True True\n", - " True False False False False True True True True True True True\n", - " True True True True True True False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " True True True True True True True True True True True True\n", - " True True True True True True True True False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False True True True True True True True True\n", - " True True True True True True True True True True False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False] (896,) [0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 1 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 1\n", - " 1 1 1 1 1 1 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 1 0 1 1 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0] (896,) 0.3275530937683716 tensor(0.3276, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.24250047241118666 tensor(0.2425, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.27537596564595973 tensor(0.2754, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.339521139314602 tensor(0.3395, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.28316756119010217 tensor(0.2832, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1024,) [0 1 0 ... 0 0 0] (1024,) 0.30224860634648365 tensor(0.3022, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.23057474330872174 tensor(0.2306, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.22791799898259513 tensor(0.2279, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.27437629915291323 tensor(0.2744, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.21319969405140976 tensor(0.2132, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.3474399687036469 tensor(0.3474, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.19751082251082253 tensor(0.1975, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.3353790123844628 tensor(0.3354, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.24501893939393937 tensor(0.2450, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.2624466475767001 tensor(0.2624, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1408,) [1 1 1 ... 0 0 0] (1408,) 0.22450973341004987 tensor(0.2245, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.27231664754255114 tensor(0.2723, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1024,) [0 0 0 ... 0 0 0] (1024,) 0.3152901785714286 tensor(0.3153, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.35922695360195356 tensor(0.3592, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... True True True] (2432,) [0 0 0 ... 0 0 0] (2432,) 0.26736473289421736 tensor(0.2674, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [1 1 1 ... 0 0 0] (8192,) 0.28538833123099405 tensor(0.2854, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.245172509039775 tensor(0.2452, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.24340502699055327 tensor(0.2434, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.28707033026885964 tensor(0.2871, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1024,) [0 0 0 ... 0 0 0] (1024,) 0.2957705135233918 tensor(0.2958, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.2895262781476896 tensor(0.2895, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.2656280862586716 tensor(0.2656, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... True True False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.19391790985177615 tensor(0.1939, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1152,) [0 0 0 ... 1 1 1] (1152,) 0.39839248075956224 tensor(0.3984, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.3400271739130435 tensor(0.3400, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.26218694096601075 tensor(0.2622, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.25949223766281415 tensor(0.2595, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.2731843170244799 tensor(0.2732, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (2432,) [0 0 0 ... 0 0 0] (2432,) 0.23153263758670284 tensor(0.2315, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.3294548915822105 tensor(0.3295, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1024,) [0 0 0 ... 0 0 0] (1024,) 0.50768331438611 tensor(0.5077, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.17941607556456285 tensor(0.1794, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.4005733735380117 tensor(0.4006, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.32525391000796444 tensor(0.3253, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.28327316031926486 tensor(0.2833, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2688,) [0 0 0 ... 0 0 0] (2688,) 0.2455340291329215 tensor(0.2455, device='cuda:0', dtype=torch.float64)\n", - "why [False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False True True True True True True True\n", - " True True True True False False False False False False False False\n", - " False False False False False False False False False False True True\n", - " True True True True True True True True True True True True\n", - " True True False False False False False False False False False False\n", - " False False True True True True True True True True True True\n", - " True True False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False True True True True False False False False\n", - " False False False False False False False False False False False False\n", - " True True True True True True True True True True True True\n", - " True True True True True True True True True True True True\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False True True True True True True True True True True True\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False True True True\n", - " True True True True True True True True False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False True True True True True True True True\n", - " True True True False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False True True True True True True True True True True\n", - " True True False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False] (896,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1\n", - " 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0 0\n", - " 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 1 1 1 1 1\n", - " 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0] (896,) 0.36439732142857145 tensor(0.3644, device='cuda:0', dtype=torch.float64)\n", - "why [False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False True True True True True False False False True\n", - " True True True True True True True True True True True True\n", - " True True True True True True True True True True True False\n", - " False False False False True True True True True True True True\n", - " True True True False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False True True True True True True True True\n", - " True True True True True True False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " True True True True True True True True True True True True\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False True True True True True True True True True True\n", - " True True True True False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False True True True True True True True\n", - " True True True True True True True False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False True True True True\n", - " True True True True True True True True True True True True\n", - " True True True True True True False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False True True True True True True\n", - " True True True True True True True True False False False False\n", - " False False False False False False False False False False False False] (768,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1\n", - " 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] (768,) 0.36334134615384617 tensor(0.3633, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.36328125 tensor(0.3633, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.3073375105806347 tensor(0.3073, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.42102430988608963 tensor(0.4210, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.39475473771436803 tensor(0.3948, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.3678160635096611 tensor(0.3678, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.19351388184584178 tensor(0.1935, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.24591191813804175 tensor(0.2459, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.3096451568959731 tensor(0.3096, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.24075629195519133 tensor(0.2408, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.17249526515151514 tensor(0.1725, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.2863095238095238 tensor(0.2863, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.4012790080941676 tensor(0.4013, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.32623064828253506 tensor(0.3262, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.29373969403168476 tensor(0.2937, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.3421500286608995 tensor(0.3422, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1664,) [0 0 0 ... 1 1 0] (1664,) 0.22848216513818703 tensor(0.2285, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.21294610507246378 tensor(0.2129, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.4324312010246706 tensor(0.4324, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.42839099459862173 tensor(0.4284, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (8192,) [0 0 0 ... 1 1 0] (8192,) 0.3411826173375903 tensor(0.3412, device='cuda:0', dtype=torch.float64)\n" - ] - }, - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0mbest_val_metric\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m train(\n\u001b[0m\u001b[1;32m 26\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdatasets\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/examples/train.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(algorithm, datasets, general_logger, config, epoch_offset, best_val_metric)\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0;31m# First run training\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 90\u001b[0;31m \u001b[0mrun_epoch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgeneral_logger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 91\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;31m# Then run val\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/examples/train.py\u001b[0m in \u001b[0;36mrun_epoch\u001b[0;34m(algorithm, dataset, general_logger, epoch, config, train)\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miterator\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 43\u001b[0;31m \u001b[0mbatch_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 44\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0mbatch_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/examples/algorithms/single_model_algorithm.py\u001b[0m in \u001b[0;36mupdate\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0;31m# process batch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 102\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 103\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_update\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 104\u001b[0m \u001b[0;31m# log results\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_log\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/examples/algorithms/single_model_algorithm.py\u001b[0m in \u001b[0;36m_update\u001b[0;34m(self, results)\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_grad_norm\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[0mclip_grad_norm_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_grad_norm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 122\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 123\u001b[0m self.step_schedulers(\n\u001b[1;32m 124\u001b[0m \u001b[0mis_epoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/torch/autograd/grad_mode.py\u001b[0m in \u001b[0;36mdecorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__class__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mcast\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/torch/optim/adam.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 106\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[0mbeta1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbeta2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgroup\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'betas'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 108\u001b[0;31m F.adam(params_with_grad,\n\u001b[0m\u001b[1;32m 109\u001b[0m \u001b[0mgrads\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 110\u001b[0m \u001b[0mexp_avgs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/torch/optim/functional.py\u001b[0m in \u001b[0;36madam\u001b[0;34m(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, beta1, beta2, lr, weight_decay, eps)\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0mdenom\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mmax_exp_avg_sq\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mmath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbias_correction2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 93\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 94\u001b[0;31m \u001b[0mdenom\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mexp_avg_sq\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mmath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbias_correction2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 95\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 96\u001b[0m \u001b[0mstep_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlr\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mbias_correction1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " - ] - } - ], - "source": [ - "if not config.eval_only:\n", - " ## Load saved results if resuming\n", - " resume_success = False\n", - " if resume:\n", - " save_path = os.path.join(config.log_dir, 'last_model.pth')\n", - " if not os.path.exists(save_path):\n", - " epochs = [\n", - " int(file.split('_')[0])\n", - " for file in os.listdir(config.log_dir) if file.endswith('.pth')]\n", - " if len(epochs) > 0:\n", - " latest_epoch = max(epochs)\n", - " save_path = os.path.join(config.log_dir, f'{latest_epoch}_model.pth')\n", - " try:\n", - " prev_epoch, best_val_metric = load(algorithm, save_path)\n", - " epoch_offset = prev_epoch + 1\n", - " logger.write(f'Resuming from epoch {epoch_offset} with best val metric {best_val_metric}')\n", - " resume_success = True\n", - " except FileNotFoundError:\n", - " pass\n", - "\n", - " if resume_success == False:\n", - " epoch_offset=0\n", - " best_val_metric=None\n", - " \n", - " train(\n", - " algorithm=algorithm,\n", - " datasets=datasets,\n", - " general_logger=logger,\n", - " config=config,\n", - " epoch_offset=epoch_offset,\n", - " best_val_metric=best_val_metric)\n", - "else:\n", - " if config.eval_epoch is None:\n", - " eval_model_path = os.path.join(config.log_dir, 'best_model.pth')\n", - " else:\n", - " eval_model_path = os.path.join(config.log_dir, f'{config.eval_epoch}_model.pth')\n", - " best_epoch, best_val_metric = load(algorithm, eval_model_path)\n", - " if config.eval_epoch is None:\n", - " epoch = best_epoch\n", - " else:\n", - " epoch = config.eval_epoch\n", - " evaluate(\n", - " algorithm=algorithm,\n", - " datasets=datasets,\n", - " epoch=epoch,\n", - " general_logger=logger,\n", - " config=config)\n", - "\n", - "logger.close()\n", - "for split in datasets:\n", - " datasets[split]['eval_logger'].close()\n", - " datasets[split]['algo_logger'].close()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.5" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} From 733bdb10c2359867a676b3ee93751849d4143b9f Mon Sep 17 00:00:00 2001 From: aikanor Date: Mon, 1 Feb 2021 21:29:52 -0800 Subject: [PATCH 050/244] add rough cut models/data --- examples/models/CNN_genome.py | 209 +++++++++++++++++++++++++++ wilds/datasets/encodetfbs_dataset.py | 147 +++++++++++++++++++ 2 files changed, 356 insertions(+) create mode 100644 examples/models/CNN_genome.py create mode 100644 wilds/datasets/encodetfbs_dataset.py diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py new file mode 100644 index 00000000..f0115322 --- /dev/null +++ b/examples/models/CNN_genome.py @@ -0,0 +1,209 @@ +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Beagle(nn.Module): + """ + Neural net models over genomic sequence. + Input: + - sequence_length: int (default 1000) + - Shape: (N, 5, sequence_length, 1) with batch size N. + + Output: + - prediction (Tensor): float torch tensor of shape (N, ) + + TODO: Finish docstring. + """ + def __init__(self, args): + """ + Parameters + ---------- + sequence_length : int + n_genomic_features : int + """ + super(Beagle, self).__init__() + + self.dropout = args.dropout + self.num_cell_types = 1 + self.conv1 = nn.Conv2d(5, 300, (19, 1), stride = (1, 1), padding=(9,0)) + self.conv2 = nn.Conv2d(300, 200, (11, 1), stride = (1, 1), padding = (5,0)) + self.conv3 = nn.Conv2d(200, 200, (7, 1), stride = (1, 1), padding = (4,0)) + self.bn1 = nn.BatchNorm2d(300) + self.bn2 = nn.BatchNorm2d(200) + self.bn3 = nn.BatchNorm2d(200) + self.maxpool1 = nn.MaxPool2d((3, 1)) + self.maxpool2 = nn.MaxPool2d((4, 1)) + self.maxpool3 = nn.MaxPool2d((4, 1)) + + self.fc1 = nn.Linear(4200, 1000) + self.bn4 = nn.BatchNorm1d(1000) + + self.fc2 = nn.Linear(1000, 1000) + self.bn5 = nn.BatchNorm1d(1000) + + self.fc3 = nn.Linear(1000, self.num_cell_types) + + def forward(self, s): + s = s.permute(0, 2, 1).contiguous() # batch_size x 5 x 1000 + s = s.view(-1, 5, 1000, 1) # batch_size x 5 x 1000 x 1 [5 channels] + s = self.maxpool1(F.relu(self.bn1(self.conv1(s)))) # batch_size x 300 x 333 x 1 + s = self.maxpool2(F.relu(self.bn2(self.conv2(s)))) # batch_size x 200 x 83 x 1 + s = self.maxpool3(F.relu(self.bn3(self.conv3(s)))) # batch_size x 200 x 21 x 1 + s = s.view(-1, 4200) + conv_out = s + + s = F.dropout(F.relu(self.bn4(self.fc1(s))), p=self.dropout, training=self.training) # batch_size x 1000 + s = F.dropout(F.relu(self.bn5(self.fc2(s))), p=self.dropout, training=self.training) # batch_size x 1000 + + s = self.fc3(s) + + return s, conv_out + + + +#class MLP(nn.Module): +# """Just an MLP""" +# def __init__(self, n_inputs, n_outputs, width, depth, drop_out): +# super(MLP, self).__init__() +# +# self.input = nn.Linear(n_inputs, width) +# self.dropout = nn.Dropout(dropout) +# self.hiddens = nn.ModuleList([ +# nn.Linear(width,width) +# for _ in range(depth-2)]) +# self.output = nn.Linear(width, n_outputs) +# self.n_outputs = n_outputs +# +# def forward(self, x): +# x = self.input(x) +# x = self.dropout(x) +# x = F.relu(x) +# for hidden in self.hiddens: +# x = hidden(x) +# x = self.dropout(x) +# x = F.relu(x) +# x = self.output(x) +# return x + + +""" +DeepSEA architecture (Zhou & Troyanskaya, 2015). +Based on https://github.com/FunctionLab/selene/blob/master/models/deepsea.py +""" + +class DeepSEA(nn.Module): + def __init__(self, sequence_length, n_genomic_features): + """ + Parameters + ---------- + sequence_length : int + n_genomic_features : int + """ + super(DeepSEA, self).__init__() + conv_kernel_size = 8 + pool_kernel_size = 4 + + self.conv_net = nn.Sequential( + nn.Conv1d(4, 320, kernel_size=conv_kernel_size), + nn.ReLU(inplace=True), + nn.MaxPool1d( + kernel_size=pool_kernel_size, stride=pool_kernel_size), + nn.Dropout(p=0.2), + + nn.Conv1d(320, 480, kernel_size=conv_kernel_size), + nn.ReLU(inplace=True), + nn.MaxPool1d( + kernel_size=pool_kernel_size, stride=pool_kernel_size), + nn.Dropout(p=0.2), + + nn.Conv1d(480, 960, kernel_size=conv_kernel_size), + nn.ReLU(inplace=True), + nn.Dropout(p=0.5)) + + reduce_by = conv_kernel_size - 1 + pool_kernel_size = float(pool_kernel_size) + self.n_channels = int( + np.floor( + (np.floor( + (sequence_length - reduce_by) / pool_kernel_size) + - reduce_by) / pool_kernel_size) + - reduce_by) + self.classifier = nn.Sequential( + nn.Linear(960 * self.n_channels, n_genomic_features), + nn.ReLU(inplace=True), + nn.Linear(n_genomic_features, n_genomic_features), + nn.Sigmoid()) + + def forward(self, x): + """Forward propagation of a batch. + """ + out = self.conv_net(x) + reshape_out = out.view(out.size(0), 960 * self.n_channels) + predict = self.classifier(reshape_out) + return predict + +""" +def criterion(): + return nn.BCELoss() + +def get_optimizer(lr): + # The optimizer and the parameters with which to initialize the optimizer. At a later time, we initialize the optimizer by also passing in the model parameters (`model.parameters()`). We cannot initialize the optimizer until the model has been initialized. + return (torch.optim.SGD, {"lr": lr, "weight_decay": 1e-6, "momentum": 0.9}) +""" + + + +""" +DanQ architecture (Quang & Xie, 2016). +""" + +class DanQ(nn.Module): + def __init__(self, sequence_length, n_genomic_features): + """ + Parameters + ---------- + sequence_length : int + Input sequence length + n_genomic_features : int + Total number of features to predict + """ + super(DanQ, self).__init__() + self.nnet = nn.Sequential( + nn.Conv1d(4, 320, kernel_size=26), + nn.ReLU(inplace=True), + nn.MaxPool1d( + kernel_size=13, stride=13), + nn.Dropout(0.2)) + + self.bdlstm = nn.Sequential(nn.LSTM(320, 320, num_layers=1, batch_first=True, bidirectional=True)) + + self._n_channels = math.floor( + (sequence_length - 25) / 13) + self.classifier = nn.Sequential( + nn.Dropout(0.5), + nn.Linear(self._n_channels * 640, 925), + nn.ReLU(inplace=True), + nn.Linear(925, n_genomic_features), + nn.Sigmoid()) + + def forward(self, x): + """Forward propagation of a batch. + """ + out = self.nnet(x) + reshape_out = out.transpose(0, 1).transpose(0, 2) + out, _ = self.bdlstm(reshape_out) + out = out.transpose(0, 1) + reshape_out = out.contiguous().view( + out.size(0), 640 * self._n_channels) + predict = self.classifier(reshape_out) + return predict + +""" +def criterion(): + return nn.BCELoss() + +def get_optimizer(lr): + return (torch.optim.RMSprop, {"lr": lr}) +""" \ No newline at end of file diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py new file mode 100644 index 00000000..08276aa9 --- /dev/null +++ b/wilds/datasets/encodetfbs_dataset.py @@ -0,0 +1,147 @@ +import os +import torch +import pandas as pd +import numpy as np +from wilds.datasets.wilds_dataset import WILDSDataset +from wilds.common.grouper import CombinatorialGrouper +from wilds.common.metrics.eval_metric import Accuracy +from wilds.common.eval import standard_group_eval + +import IPython + +class EncodeTFBSDataset(WILDSDataset): + """ + EncodeTFBS dataset + Website: https://www.synapse.org/#!Synapse:syn6131484 + """ + + def __init__(self, root_dir, download, split_scheme): + self._dataset_name = 'encodeTFBS' + self._download_url = 'https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/' + self._data_dir = self.initialize_data_dir(root_dir, download) + self._y_size = 1 + self._n_classes = 2 + + self._tr_chrs = ['chr2', 'chr9', 'chr11'] + self._te_chrs = ['chr1', 'chr8', 'chr21'] + self._transcription_factor = 'MAX' + self._train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] + self._val_celltype = ['A549'] + self._test_celltype = ['GM12878'] + self._all_celltypes = self._train_celltypes + self._val_celltype + self._test_celltype + + self._metadata_fields = ['chr', 'celltype', 'y'] + self._metadata_map = {} + self._metadata_map['chr'] = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] + self._metadata_map['celltype'] = self._all_celltypes + + # Load sequence and DNase features + sequence_filename = os.path.join(self._data_dir, 'sequence.npz') + seq_arr = np.load(sequence_filename) + self._seq_bp = {} + for chrom in seq_arr: + self._seq_bp[chrom] = seq_arr[chrom] + + self._dnase_allcelltypes = {} + for ct in self._all_celltypes: + dnase_filename = os.path.join(self._data_dir, '{}_dnase.npz'.format(ct)) + dnase_npz_file = np.load(dnase_filename) + self._dnase_allcelltypes[ct] = {} + for chrom in seq_bp: + self._dnase_allcelltypes[ct][chrom] = dnase_npz_file[chrom] + + # Read in metadata dataframe from training+validation data + train_chr = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.train.labels.tsv.gz'.format(self._transcription_factor)), sep='\t') + val_chr = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.val.labels.tsv.gz'.format(self._transcription_factor)), sep='\t') + training_df = train_chr[np.isin(train_chr['chr'], self._tr_chrs)] + val_df = val_chr[np.isin(val_chr['chr'], self._te_chrs)] + all_df = pd.concat([training_df, val_df]) + + # Filter by start/stop coordinate if needed + filter_msk = all_df['start'] >= 0 + filter_msk = all_df['start']%1000 == 0 + all_df = all_df[filter_msk] + + pd_list = [] + for ct in self._train_celltypes: + tc_chr = all_df[['chr', 'start', 'stop', ct]] + tc_chr.columns = ['chr', 'start', 'stop', 'y'] + tc_chr['celltype'] = ct + pd_list.append(tc_chr) + metadata_df = pd.concat(pd_list) + + # Get the y values, and remove ambiguous labels by default. + y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values + non_ambig_mask = (y_array != -1) + metadata_df['y'] = y_array + self._metadata_df = metadata_df[non_ambig_mask] + self._y_array = torch.LongTensor(y_array[non_ambig_mask]) + + chr_ints = self._metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['chr'])] )).values + celltype_ints = self._metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['celltype'])] )).values + self._metadata_array = torch.stack( + (torch.LongTensor(chr_ints), + torch.LongTensor(celltype_ints), + self._y_array), + dim=1) + + # Get the splits + # TODO Extract splits as encoded in split_scheme. Hardcoded here for now. + self._split_scheme = split_scheme + self._split_dict = { + 'train': 0, + 'val-id': 1, + 'test': 2, + 'val-ood': 3 + } + self._split_names = { + 'train': 'Train', + 'val-id': 'Validation (ID)', + 'test': 'Test', + 'val-ood': 'Validation (OOD)', + } + train_chr_mask = np.isin(self._metadata_df['chr'], self._tr_chrs) + val_chr_mask = np.isin(self._metadata_df['chr'], self._te_chrs) + train_celltype_mask = np.isin(self._metadata_df['celltype'], self._train_celltypes) + val_celltype_mask = np.isin(self._metadata_df['celltype'], self._val_celltype) + test_celltype_mask = np.isin(self._metadata_df['celltype'], self._test_celltype) + + split_array = -1*np.ones(self._metadata_df.shape[0]).astype(int) + split_array[np.logical_and(train_chr_mask, train_celltype_mask)] = self._split_dict['train'] + split_array[np.logical_and(val_chr_mask, test_celltype_mask)] = self._split_dict['test'] + # Validate using test chr, either using a designated validation cell line ('val-ood') or a training cell line ('val-id') + split_array[np.logical_and(val_chr_mask, val_celltype_mask)] = self._split_dict['val-ood'] + split_array[np.logical_and(val_chr_mask, train_celltype_mask)] = self._split_dict['val-id'] + if self._split_scheme=='standard': + self._metadata_df['split'] = split_array + self._split_array = split_array + else: + raise ValueError(f'Split scheme {self._split_scheme} not recognized') + self._eval_grouper = CombinatorialGrouper( + dataset=self, + groupby_fields=['celltype']) + self._metric = Auprc() + + super().__init__(root_dir, download, split_scheme) + + def get_input(self, idx): + """ + Returns x for a given idx. + Computes this from: + (1) sequence features in self._seq_bp + (2) DNase features in self._dnase_allcelltypes + (3) Metadata for the index (location along the genome with 1kb window width) + """ + this_metadata = self._metadata_df.iloc[idx, :] + flank_size = 500 + interval_start = this_metadata['start'] - flank_size + interval_end = this_metadata['stop'] + flank_size + dnase_this = _dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end] + seq_this = _seq_bp[this_metadata['chr']][interval_start:interval_end] + return np.column_stack([seq_this, dnase_this]) + + def eval(self, y_pred, y_true, metadata): + return standard_group_eval( + self._metric, + self._eval_grouper, + y_pred, y_true, metadata) From a65acbcd0b50296d04e2e5853c2f048f99b15a34 Mon Sep 17 00:00:00 2001 From: aikanor Date: Tue, 2 Feb 2021 10:56:31 -0800 Subject: [PATCH 051/244] model/dataset fetching in nb 1/ --- .../sandbox_data-checkpoint.ipynb | 952 ++++++++++++++++++ .../encode-tfbs/sandbox_data.ipynb | 952 ++++++++++++++++++ sandbox_model.ipynb | 876 ++++++++++++++++ 3 files changed, 2780 insertions(+) create mode 100644 dataset_preprocessing/encode-tfbs/.ipynb_checkpoints/sandbox_data-checkpoint.ipynb create mode 100644 dataset_preprocessing/encode-tfbs/sandbox_data.ipynb create mode 100644 sandbox_model.ipynb diff --git a/dataset_preprocessing/encode-tfbs/.ipynb_checkpoints/sandbox_data-checkpoint.ipynb b/dataset_preprocessing/encode-tfbs/.ipynb_checkpoints/sandbox_data-checkpoint.ipynb new file mode 100644 index 00000000..b2e74829 --- /dev/null +++ b/dataset_preprocessing/encode-tfbs/.ipynb_checkpoints/sandbox_data-checkpoint.ipynb @@ -0,0 +1,952 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Initialize dataset object" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "57.5289368629\n", + "65.2459537983\n" + ] + } + ], + "source": [ + "import numpy as np, pandas as pd, os, time\n", + "import torch, torchvision\n", + "\n", + "data_dir = '/oak/stanford/groups/akundaje/abalsubr/DREAM/wilds/codalab_archive/'\n", + "tf = 'MAX'\n", + "itime = time.time()\n", + "train_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.train.labels.tsv.gz'.format(tf)), sep='\\t')\n", + "print(time.time() - itime)\n", + "val_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.val.labels.tsv.gz'.format(tf)), sep='\\t')\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", + "val_celltype = ['A549']\n", + "test_celltype = ['GM12878']\n", + "all_celltypes = train_celltypes + val_celltype + test_celltype\n", + "\n", + "metadata_map = {}\n", + "metadata_map['chr'] = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']\n", + "metadata_map['celltype'] = all_celltypes\n", + "\n", + "_split_dict = {\n", + " 'train': 0,\n", + " 'val-id': 1,\n", + " 'test': 2,\n", + " 'val-ood': 3\n", + "}\n", + "_split_names = {\n", + " 'train': 'Train',\n", + " 'val-id': 'Validation (ID)',\n", + " 'test': 'Test',\n", + " 'val-ood': 'Validation (OOD)',\n", + "}\n", + "_split_scheme = 'standard'" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.0467748641968\n", + "('chr1', 4.52302885055542)\n", + "('chr2', 8.645489931106567)\n", + "('chr3', 11.959153890609741)\n", + "('chr4', 15.15813684463501)\n", + "('chr5', 18.22238802909851)\n", + "('chr6', 21.19420099258423)\n", + "('chr7', 23.940655946731567)\n", + "('chr8', 26.415233850479126)\n", + "('chr9', 28.833614826202393)\n", + "('chr10', 31.08920383453369)\n", + "('chr11', 33.37020301818848)\n", + "('chr12', 35.98973989486694)\n", + "('chr13', 37.88540601730347)\n", + "('chr14', 39.68082284927368)\n", + "('chr15', 41.242313861846924)\n", + "('chr16', 42.74874496459961)\n", + "('chr17', 44.12280797958374)\n", + "('chr18', 45.46893382072449)\n", + "('chr19', 46.50577902793884)\n", + "('chr20', 47.59563183784485)\n", + "('chr21', 48.31779384613037)\n", + "('chr22', 49.17265295982361)\n", + "('chrX', 51.75806999206543)\n", + "('H1-hESC', 25.880441904067993)\n", + "('HCT116', 50.130937814712524)\n", + "('HeLa-S3', 75.29559993743896)\n", + "('HepG2', 102.25979495048523)\n", + "('K562', 128.43050694465637)\n", + "('A549', 154.80679488182068)\n", + "('GM12878', 182.0279529094696)\n" + ] + } + ], + "source": [ + "itime = time.time()\n", + "sequence_filename = os.path.join(data_dir, 'sequence.npz')\n", + "seq_arr = np.load(sequence_filename)\n", + "print(time.time() - itime)\n", + "\n", + "itime = time.time()\n", + "_seq_bp = {}\n", + "for chrom in seq_arr:\n", + " _seq_bp[chrom] = seq_arr[chrom]\n", + " print(chrom, time.time() - itime)\n", + "\n", + "itime = time.time()\n", + "_dnase_allcelltypes = {}\n", + "for ct in all_celltypes:\n", + " dnase_filename = os.path.join(data_dir, '{}_dnase.npz'.format(ct))\n", + " dnase_npz_file = np.load(dnase_filename)\n", + " _dnase_allcelltypes[ct] = {}\n", + " for chrom in _seq_bp:\n", + " _dnase_allcelltypes[ct][chrom] = dnase_npz_file[chrom]\n", + " print(ct, time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'all_df' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# len(_dnase_allcelltypes)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mall_df\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mNameError\u001b[0m: name 'all_df' is not defined" + ] + } + ], + "source": [ + "# len(_dnase_allcelltypes)\n", + "all_df" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "ename": "AttributeError", + "evalue": "'module' object has no attribute 'isin'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mtr_chrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr2'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr9'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr11'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mte_chrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr1'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr8'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr21'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mtraining_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtr_chrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mval_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mval_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mte_chrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mall_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtraining_df\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_df\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mAttributeError\u001b[0m: 'module' object has no attribute 'isin'" + ] + } + ], + "source": [ + "tr_chrs = ['chr2', 'chr9', 'chr11']\n", + "te_chrs = ['chr1', 'chr8', 'chr21']\n", + "training_df = train_chr[np.isin(train_chr['chr'], tr_chrs)]\n", + "val_df = val_chr[np.isin(val_chr['chr'], te_chrs)]\n", + "all_df = pd.concat([training_df, val_df])\n", + "\n", + "#filter_msk = all_df['start'] >= 0\n", + "filter_msk = all_df['start']%1000 == 0\n", + "all_df = all_df[filter_msk]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "itime = time.time()\n", + "pd_list = []\n", + "for ct in all_celltypes:\n", + " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", + " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", + " tc_chr['celltype'] = ct\n", + " pd_list.append(tc_chr)\n", + "metadata_df = pd.concat(pd_list)\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "itime = time.time()\n", + "y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", + "non_ambig_mask = (y_array != -1)\n", + "metadata_df['y'] = y_array\n", + "_metadata_df = metadata_df[non_ambig_mask]\n", + "_y_array = torch.LongTensor(y_array[non_ambig_mask])\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "itime = time.time()\n", + "chr_ints = _metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['chr'])] )).values\n", + "celltype_ints = _metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['celltype'])] )).values\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_chr_mask = np.isin(_metadata_df['chr'], tr_chrs)\n", + "val_chr_mask = np.isin(_metadata_df['chr'], te_chrs)\n", + "train_celltype_mask = np.isin(_metadata_df['celltype'], train_celltypes)\n", + "val_celltype_mask = np.isin(_metadata_df['celltype'], val_celltype)\n", + "test_celltype_mask = np.isin(_metadata_df['celltype'], test_celltype)\n", + "\n", + "split_array = -1*np.ones(_metadata_df.shape[0]).astype(int)\n", + "split_array[np.logical_and(train_chr_mask, train_celltype_mask)] = _split_dict['train']\n", + "split_array[np.logical_and(val_chr_mask, test_celltype_mask)] = _split_dict['test']\n", + "split_array[np.logical_and(val_chr_mask, val_celltype_mask)] = _split_dict['val-ood']\n", + "split_array[np.logical_and(val_chr_mask, train_celltype_mask)] = _split_dict['val-id']\n", + "_metadata_df['split'] = split_array\n", + "_split_array = split_array" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "ename": "ImportError", + "evalue": "No module named data", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdataset_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mImportError\u001b[0m: No module named data" + ] + } + ], + "source": [ + "from torch.utils.data import DataLoader\n", + "from data import dataset_attributes" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "from PIL import Image\n", + "import argparse\n", + "class ParseKwargs(argparse.Action):\n", + " def __call__(self, parser, namespace, values, option_string=None):\n", + " setattr(namespace, self.dest, dict())\n", + " for value in values:\n", + " key, value_str = value.split('=')\n", + " if value_str.replace('-','').isnumeric():\n", + " processed_val = int(value_str)\n", + " elif value_str.replace('-','').replace('.','').isnumeric():\n", + " processed_val = float(value_str)\n", + " elif value_str in ['True', 'true']:\n", + " processed_val = True\n", + " elif value_str in ['False', 'false']:\n", + " processed_val = False\n", + " else:\n", + " processed_val = value_str\n", + " getattr(namespace, self.dest)[key] = processed_val" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'algorithm_constructors' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0;31m# Algorithm and objective\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 33\u001b[0;31m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'--algorithm'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrequired\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mchoices\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0malgorithm_constructors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 34\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'--algorithm_kwargs'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'*'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maction\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mParseKwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'--groupby_fields'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'+'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name 'algorithm_constructors' is not defined" + ] + } + ], + "source": [ + "ROOTDIR = '/oak/stanford/groups/akundaje/abalsubr/wilds_other'\n", + "args_kw = \"-d camelyon17 --algorithm ERM --model densenet121 --split_scheme standard --groupby_fields hospital --loss_function cross_entropy --optimizer SGD --lr 0.0001 --batch_size 32 --weight_decay 0 --n_epochs 10 --scheduler ReduceLROnPlateau --scheduler_metric_split val --scheduler_metric_name acc_avg --log_dir log --log_every 50 --save_step 1000 --save_best --save_last --seed 0 --evaluate_all_splits --root_dir {}\".format(\n", + " ROOTDIR).split()\n", + "\n", + "parser = argparse.ArgumentParser()\n", + "\n", + "# Dataset\n", + "parser.add_argument('-d', '--dataset', choices=['encodeTFBS', 'amazon', 'camelyon17', 'celebA', 'civilcomments', 'iwildcam', 'waterbirds', 'yelp', 'poverty', 'fmow', 'ogbg-molpcba'], required=True)\n", + "parser.add_argument('--split_scheme', default='standard',\n", + " help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')\n", + "parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--root_dir', default=None, required=True,\n", + " help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')\n", + "parser.add_argument('--download', default=False, action='store_true',\n", + " help='If true, tries to downloads the dataset if it does not exist in root_dir.')\n", + "parser.add_argument('--frac', type=float, default=1.0,\n", + " help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.')\n", + "\n", + "# Loaders\n", + "parser.add_argument('--train_loader', choices=['standard', 'group'], default='standard')\n", + "parser.add_argument('--train_loader_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--eval_loader', choices=['standard', 'group'], default='standard')\n", + "parser.add_argument('--batch_size', type=int, default=32)\n", + "parser.add_argument('--no_pin_memory', action='store_true') # TODO: put as loader_kwargs\n", + "parser.add_argument('--num_workers', type=int, default=4) # TODO: put as loader kwargs\n", + "\n", + "# Model\n", + "parser.add_argument(\n", + " '--model',\n", + " choices=['bert-base-uncased', 'inception_v3', 'densenet121', 'wideresnet50', 'resnet50', 'gin-virtual', 'resnet18_ms'],\n", + " default='resnet50')\n", + "parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},\n", + " help='keyword arguments for model initialization passed as key1=value1 key2=value2')\n", + "parser.add_argument('--train_from_scratch', action='store_true', default=False)\n", + "\n", + "# Algorithm and objective\n", + "parser.add_argument('--algorithm', required=True, choices=algorithm_constructors.keys())\n", + "parser.add_argument('--algorithm_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--groupby_fields', nargs='+', default=None)\n", + "parser.add_argument('--loss_function', required=True, choices = losses.keys()) #TODO: make default\n", + "parser.add_argument('--val_metric', default=None)\n", + "\n", + "# Optimization\n", + "parser.add_argument('--n_epochs', type=int, default=4)\n", + "parser.add_argument('--optimizer', default=None, choices=optimizer_attributes.keys())\n", + "parser.add_argument('--lr', type=float, required=True)\n", + "parser.add_argument('--weight_decay', type=float, required=True)\n", + "parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--scheduler', default=None, choices=scheduler_attributes.keys())\n", + "parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')\n", + "parser.add_argument('--scheduler_metric_name')\n", + "\n", + "# Evaluation\n", + "parser.add_argument('--evaluate_all_splits', action='store_true', default=False)\n", + "parser.add_argument('--additional_eval_splits', nargs='+', default=[])\n", + "\n", + "# Misc\n", + "parser.add_argument('--device', type=int, default=0)\n", + "parser.add_argument('--seed', type=int, default=0)\n", + "parser.add_argument('--log_dir', default='./logs')\n", + "parser.add_argument('--log_every', default=50, type=int)\n", + "parser.add_argument('--save_step', type=int, default=None)\n", + "parser.add_argument('--save_best', action='store_true', default=False)\n", + "parser.add_argument('--save_last', action='store_true', default=False)\n", + "parser.add_argument('--save_outputs', action='store_true', default=False)\n", + "parser.add_argument('--no_group_logging', action='store_true', default=False)\n", + "parser.add_argument('--val_metric_decreasing', action='store_true', default=False)\n", + "parser.add_argument('--use_wandb', action='store_true', default=False)\n", + "parser.add_argument('--progress_bar', action='store_true', default=False)\n", + "parser.add_argument('--resume', default=False, action='store_true')\n", + "parser.add_argument('--eval_only', default=False, action='store_true')\n", + "\n", + "args = parser.parse_args(args_kw)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# get_input (idx)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name '_metadata_df' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0midx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mthis_metadata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_metadata_df\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miloc\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mitime\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mflank_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m400\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name '_metadata_df' is not defined" + ] + } + ], + "source": [ + "idx = 3\n", + "this_metadata = _metadata_df.iloc[idx, :]\n", + "\n", + "itime = time.time()\n", + "flank_size = 400\n", + "interval_start = this_metadata['start'] - flank_size\n", + "interval_end = this_metadata['stop'] + flank_size\n", + "dnase_this = _dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end]\n", + "seq_this = _seq_bp[this_metadata['chr']][interval_start:interval_end]\n", + "data = np.column_stack([seq_this, dnase_this])\n", + "# print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.028102874755859375\n" + ] + } + ], + "source": [ + "itime = time.time()\n", + "metadata_array = torch.stack(\n", + " (torch.LongTensor(chr_ints), \n", + " torch.LongTensor(celltype_ints), \n", + " _y_array),\n", + " dim=1)\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'torch_scatter'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m#data.shape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata_loaders\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_train_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mget_eval_loader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/dr_benchmark/wilds/common/data_loaders.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msampler\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mWeightedRandomSampler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSubsetRandomSampler\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_counts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msplit_into_groups\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/dr_benchmark/wilds/common/utils.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch_scatter\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mSubset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mpandas\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtypes\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCategoricalDtype\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torch_scatter'" + ] + } + ], + "source": [ + "#data.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 157, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "4600" + ] + }, + "execution_count": 157, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data.shape\n", + "interval_end\n", + "# itime = time.time()\n", + "# np.save(os.path.join(data_dir, 'stmp.npy'), sa)\n", + "# print(time.time() - itime)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Run training experiment" + ] + }, + { + "cell_type": "code", + "execution_count": 167, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'python3 examples/run_expt.py -d encodeTFBS --algorithm ERM --model densenet121 --split_scheme standard --groupby_fields hospital --loss_function cross_entropy --optimizer SGD --lr 0.0001 --batch_size 32 --weight_decay 0 --n_epochs 10 --scheduler ReduceLROnPlateau --scheduler_metric_split val --scheduler_metric_name acc_avg --log_dir log --log_every 50 --save_step 1000 --save_best --save_last --seed 0 --evaluate_all_splits --root_dir ROOTDIR'" + ] + }, + "execution_count": 167, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cmdstr = \"python3 examples/run_expt.py -d encodeTFBS --algorithm ERM --model densenet121 --split_scheme standard --groupby_fields hospital --loss_function cross_entropy\"\n", + "cmdstr += \" \"\n", + "cmdstr += \"--optimizer SGD --lr 0.0001 --batch_size 32 --weight_decay 0 --n_epochs 10 --scheduler ReduceLROnPlateau --scheduler_metric_split val --scheduler_metric_name acc_avg\"\n", + "cmdstr += \" \"\n", + "cmdstr += \"--log_dir log --log_every 50 --save_step 1000 --save_best --save_last --seed 0 --evaluate_all_splits --root_dir ROOTDIR\"\n", + "cmdstr" + ] + }, + { + "cell_type": "code", + "execution_count": 164, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name '_metadata_array' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0m_metadata_array\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mNameError\u001b[0m: name '_metadata_array' is not defined" + ] + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 165, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'torch_scatter'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minsert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'..'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata_loaders\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_train_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mget_eval_loader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 17\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrouper\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCombinatorialGrouper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_counts\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/dr_benchmark/wilds/common/data_loaders.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msampler\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mWeightedRandomSampler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSubsetRandomSampler\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_counts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msplit_into_groups\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/dr_benchmark/wilds/common/utils.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch_scatter\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mSubset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mpandas\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtypes\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCategoricalDtype\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torch_scatter'" + ] + } + ], + "source": [ + "import os, csv\n", + "import time\n", + "import argparse\n", + "import IPython\n", + "import pandas as pd\n", + "import torch\n", + "import torch.nn as nn\n", + "import torchvision\n", + "import sys\n", + "from collections import defaultdict\n", + "# torch.multiprocessing.set_sharing_strategy('file_system')\n", + "\n", + "# TODO: Replace this once we make wilds into an installed package\n", + "sys.path.insert(1, os.path.join(sys.path[0], '..'))\n", + "\n", + "from wilds.common.data_loaders import get_train_loader, get_eval_loader\n", + "from wilds.common.grouper import CombinatorialGrouper\n", + "from wilds.common.utils import get_counts\n", + "\n", + "from models.model_attributes import model_attributes\n", + "from utils import set_seed, Logger, BatchLogger, log_args, ParseKwargs, load\n", + "from train import train, evaluate\n", + "from data import dataset_attributes\n", + "from optimizer import optimizer_attributes\n", + "from scheduler import scheduler_attributes\n", + "from loss import losses\n", + "from utils import log_group_data\n", + "from algorithms.constructors import algorithm_constructors" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from examples.models.model_attributes import model_attributes" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'utils'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_attributes\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmodel_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mset_seed\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mLogger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mCSVBatchLogger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mParseKwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mload\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 22\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdataset_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0moptimizer_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/dr_benchmark/examples/train.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0msave\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'utils'" + ] + } + ], + "source": [ + "def initialize_algorithm(args, datasets, train_grouper):\n", + " train_dataset = datasets['train']['dataset']\n", + " train_loader = datasets['train']['loader']\n", + "\n", + " # Configure the final layer of the networks used\n", + " # The code below are defaults. Edit this if you need special config for your model.\n", + " if (train_dataset.is_classification) and (train_dataset.y_size == 1):\n", + " # For single-task classification, we have one output per class\n", + " d_out = train_dataset.n_classes\n", + " elif (train_dataset.is_classification) and (train_dataset.y_size > 1) and (train_dataset.n_classes == 2):\n", + " # For multi-task binary classification (each output is the logit for each binary class)\n", + " d_out = train_dataset.y_size\n", + " elif (not train_dataset.is_classification):\n", + " # For regression, we have one output per target dimension\n", + " d_out = train_dataset.y_size\n", + " else:\n", + " raise RuntimeError('d_out not defined.')\n", + " \n", + "\n", + " # Sanity checking input args\n", + " if args.algorithm == 'groupDRO':\n", + " assert args.train_loader_kwargs['uniform_over_groups']\n", + " elif args.algorithm in ['deepCORAL', 'IRM']:\n", + " assert args.train_loader == 'group'\n", + " assert args.train_loader_kwargs['uniform_over_groups']\n", + " assert args.train_loader_kwargs['distinct_groups']\n", + "\n", + " # Other config\n", + " n_train_steps = len(train_loader) * args.n_epochs\n", + "# prediction_fn = dataset_attributes[args.dataset]['prediction_fn']\n", + " loss = losses[args.loss_function]\n", + " metric = dataset_attributes[args.dataset]['metric']\n", + " train_g = train_grouper.metadata_to_group(train_dataset.metadata_array)\n", + " is_group_in_train = get_counts(train_g, train_grouper.n_groups) > 0\n", + " algorithm_constructor = algorithm_constructors[args.algorithm]\n", + " algorithm = algorithm_constructor(\n", + " args=args,\n", + " d_out=d_out,\n", + " grouper=train_grouper,\n", + " loss=loss,\n", + " metric=metric,\n", + " n_train_steps=n_train_steps,\n", + " is_group_in_train=is_group_in_train)\n", + " return algorithm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def main():\n", + " parser = argparse.ArgumentParser()\n", + "\n", + " # Dataset\n", + " parser.add_argument('-d', '--dataset', choices=dataset_attributes.keys(), required=True)\n", + " parser.add_argument('--split_scheme', default='standard',\n", + " help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')\n", + " parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={})\n", + " parser.add_argument('--root_dir', default=None, required=True,\n", + " help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')\n", + " parser.add_argument('--download', default=False, action='store_true',\n", + " help='If true, tries to downloads the dataset if it does not exist in root_dir.')\n", + " parser.add_argument('--frac', type=float, default=1.0,\n", + " help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.')\n", + "\n", + " # Loaders\n", + " parser.add_argument('--train_loader', choices=['standard', 'group'], default='standard')\n", + " parser.add_argument('--train_loader_kwargs', nargs='*', action=ParseKwargs, default={})\n", + " parser.add_argument('--eval_loader', choices=['standard', 'group'], default='standard')\n", + " parser.add_argument('--batch_size', type=int, default=32)\n", + " parser.add_argument('--no_pin_memory', action='store_true') # TODO: put as loader_kwargs\n", + " parser.add_argument('--num_workers', type=int, default=4) # TODO: put as loader kwargs\n", + "\n", + " # Model\n", + " parser.add_argument(\n", + " '--model',\n", + " choices=model_attributes.keys(),\n", + " default='resnet50')\n", + " parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},\n", + " help='keyword arguments for model initialization passed as key1=value1 key2=value2')\n", + " parser.add_argument('--train_from_scratch', action='store_true', default=False)\n", + "\n", + " # Algorithm and objective\n", + " parser.add_argument('--algorithm', required=True, choices=algorithm_constructors.keys())\n", + " parser.add_argument('--algorithm_kwargs', nargs='*', action=ParseKwargs, default={})\n", + " parser.add_argument('--groupby_fields', nargs='+', default=None)\n", + " parser.add_argument('--loss_function', required=True, choices = losses.keys()) #TODO: make default\n", + " parser.add_argument('--val_metric', default=None)\n", + "\n", + " # Optimization\n", + " parser.add_argument('--n_epochs', type=int, default=4)\n", + " parser.add_argument('--optimizer', default=None, choices=optimizer_attributes.keys())\n", + " parser.add_argument('--lr', type=float, required=True)\n", + " parser.add_argument('--weight_decay', type=float, required=True)\n", + " parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})\n", + " parser.add_argument('--scheduler', default=None, choices=scheduler_attributes.keys())\n", + " parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})\n", + " parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')\n", + " parser.add_argument('--scheduler_metric_name')\n", + "\n", + " # Evaluation\n", + " parser.add_argument('--evaluate_all_splits', action='store_true', default=False)\n", + " parser.add_argument('--additional_eval_splits', nargs='+', default=[])\n", + "\n", + " # Misc\n", + " parser.add_argument('--device', type=int, default=0)\n", + " parser.add_argument('--seed', type=int, default=0)\n", + " parser.add_argument('--log_dir', default='./logs')\n", + " parser.add_argument('--log_every', default=50, type=int)\n", + " parser.add_argument('--save_step', type=int, default=None)\n", + " parser.add_argument('--save_best', action='store_true', default=False)\n", + " parser.add_argument('--save_last', action='store_true', default=False)\n", + " parser.add_argument('--save_outputs', action='store_true', default=False)\n", + " parser.add_argument('--no_group_logging', action='store_true', default=False)\n", + " parser.add_argument('--val_metric_decreasing', action='store_true', default=False)\n", + " parser.add_argument('--use_wandb', action='store_true', default=False)\n", + " parser.add_argument('--progress_bar', action='store_true', default=False)\n", + " parser.add_argument('--resume', default=False, action='store_true')\n", + " parser.add_argument('--eval_only', default=False, action='store_true')\n", + "\n", + " args = parser.parse_args()\n", + "\n", + " # set device\n", + " args.device = torch.device(\"cuda:\" + str(args.device)) if torch.cuda.is_available() else torch.device(\"cpu\")\n", + "\n", + " # Set defaults\n", + " if args.groupby_fields is None:\n", + " args.no_group_logging = True\n", + " if args.val_metric is None:\n", + " args.val_metric = dataset_attributes[args.dataset]['val_metric']\n", + "\n", + " ## Initialize logs\n", + " if os.path.exists(args.log_dir) and args.resume:\n", + " resume=True\n", + " mode='a'\n", + " else:\n", + " resume=False\n", + " mode='w'\n", + " if not os.path.exists(args.log_dir):\n", + " os.makedirs(args.log_dir)\n", + " logger = Logger(os.path.join(args.log_dir, 'log.txt'), mode)\n", + "\n", + " # Record args\n", + " log_args(args, logger)\n", + "\n", + " # Set random seed\n", + " set_seed(args.seed)\n", + "\n", + " # Data\n", + " full_dataset = dataset_attributes[args.dataset]['constructor'](\n", + " root_dir=args.root_dir,\n", + " download=args.download,\n", + " split_scheme=args.split_scheme,\n", + " **args.dataset_kwargs)\n", + "\n", + " # To implement data augmentation (i.e., have different transforms\n", + " # at training time vs. test time), modify these two lines:\n", + " train_transform = dataset_attributes[args.dataset]['transform'](args.model)\n", + " if dataset_attributes[args.dataset].get('eval_transform') is None:\n", + " eval_transform = dataset_attributes[args.dataset]['transform'](args.model)\n", + " else:\n", + " eval_transform = dataset_attributes[args.dataset]['eval_transform'](args.model)\n", + "\n", + " train_grouper = CombinatorialGrouper(\n", + " dataset=full_dataset,\n", + " groupby_fields=args.groupby_fields)\n", + "\n", + " datasets = defaultdict(dict)\n", + " for split in full_dataset.split_dict.keys():\n", + " if split=='train':\n", + " transform = train_transform\n", + " verbose = True\n", + " elif split == 'val':\n", + " transform = eval_transform\n", + " verbose = True\n", + " else:\n", + " transform = eval_transform\n", + " verbose = False\n", + " # Get subset\n", + " datasets[split]['dataset'] = full_dataset.get_subset(\n", + " split,\n", + " frac=args.frac,\n", + " transform=transform)\n", + "\n", + " # Get loader\n", + " shared_loader_kwargs = {\n", + " 'num_workers': args.num_workers,\n", + " 'pin_memory': not args.no_pin_memory,\n", + " 'batch_size': args.batch_size,\n", + " 'collate_fn': dataset_attributes[args.dataset]['collate']\n", + " }\n", + "\n", + " if split == 'train':\n", + " datasets[split]['loader'] = get_train_loader(\n", + " loader=args.train_loader,\n", + " dataset=datasets[split]['dataset'],\n", + " grouper=train_grouper,\n", + " train_loader_kwargs=args.train_loader_kwargs,\n", + " **shared_loader_kwargs)\n", + " else:\n", + " datasets[split]['loader'] = get_eval_loader(\n", + " loader=args.eval_loader,\n", + " dataset=datasets[split]['dataset'],\n", + " grouper=train_grouper,\n", + " **shared_loader_kwargs)\n", + "\n", + " # Set fields\n", + " datasets[split]['split'] = split\n", + " datasets[split]['name'] = full_dataset.split_names[split]\n", + " datasets[split]['verbose'] = verbose\n", + " # Loggers\n", + " # Loggers\n", + " datasets[split]['eval_logger'] = BatchLogger(\n", + " os.path.join(args.log_dir, f'{split}_eval.csv'), mode=mode, use_wandb=args.use_wandb)\n", + " datasets[split]['algo_logger'] = BatchLogger(\n", + " os.path.join(args.log_dir, f'{split}_algo.csv'), mode=mode, use_wandb=args.use_wandb)\n", + "\n", + " if args.use_wandb:\n", + " initialize_wandb(args)\n", + "\n", + " # Logging dataset info\n", + " if args.no_group_logging and full_dataset.is_classification and full_dataset.y_size==1:\n", + " log_grouper = CombinatorialGrouper(\n", + " dataset=full_dataset,\n", + " groupby_fields=['y'])\n", + " elif args.no_group_logging:\n", + " log_grouper = None\n", + " else:\n", + " log_grouper = train_grouper\n", + " log_group_data(args, datasets, log_grouper, logger)\n", + "\n", + " ## Initialize algorithm\n", + " algorithm = initialize_algorithm(args, datasets, train_grouper)\n", + "\n", + " if not args.eval_only:\n", + " ## Load saved results if resuming\n", + " resume_success = False\n", + " if resume:\n", + " save_path = os.path.join(args.log_dir, 'last_model.pth')\n", + " if not os.path.exists(save_path):\n", + " epochs = [\n", + " int(file.split('_')[0])\n", + " for file in os.listdir(args.log_dir) if file.endswith('.pth')]\n", + " if len(epochs) > 0:\n", + " latest_epoch = max(epochs)\n", + " save_path = os.path.join(args.log_dir, f'{latest_epoch}_model.pth')\n", + " try:\n", + " prev_epoch, best_val_metric = load(algorithm, save_path)\n", + " epoch_offset = prev_epoch + 1\n", + " logger.write(f'Resuming from epoch {epoch_offset} with best val metric {best_val_metric}')\n", + " resume_success = True\n", + " except FileNotFoundError:\n", + " pass\n", + "\n", + " if resume_success == False:\n", + " epoch_offset=0\n", + " best_val_metric=None\n", + "\n", + "\n", + " train(algorithm,\n", + " datasets,\n", + " logger,\n", + " args,\n", + " epoch_offset=epoch_offset,\n", + " best_val_metric=best_val_metric)\n", + " else:\n", + " best_model_path = os.path.join(args.log_dir, 'best_model.pth')\n", + " best_epoch, best_val_metric = load(algorithm, best_model_path)\n", + " evaluate(algorithm, datasets, best_epoch, logger)\n", + "\n", + " logger.close()\n", + " for split in datasets:\n", + " datasets[split]['eval_logger'].close()\n", + " datasets[split]['algo_logger'].close()\n", + "\n", + "if __name__=='__main__':\n", + " main()\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "language": "python", + "name": "python2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/dataset_preprocessing/encode-tfbs/sandbox_data.ipynb b/dataset_preprocessing/encode-tfbs/sandbox_data.ipynb new file mode 100644 index 00000000..b2e74829 --- /dev/null +++ b/dataset_preprocessing/encode-tfbs/sandbox_data.ipynb @@ -0,0 +1,952 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Initialize dataset object" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "57.5289368629\n", + "65.2459537983\n" + ] + } + ], + "source": [ + "import numpy as np, pandas as pd, os, time\n", + "import torch, torchvision\n", + "\n", + "data_dir = '/oak/stanford/groups/akundaje/abalsubr/DREAM/wilds/codalab_archive/'\n", + "tf = 'MAX'\n", + "itime = time.time()\n", + "train_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.train.labels.tsv.gz'.format(tf)), sep='\\t')\n", + "print(time.time() - itime)\n", + "val_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.val.labels.tsv.gz'.format(tf)), sep='\\t')\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", + "val_celltype = ['A549']\n", + "test_celltype = ['GM12878']\n", + "all_celltypes = train_celltypes + val_celltype + test_celltype\n", + "\n", + "metadata_map = {}\n", + "metadata_map['chr'] = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']\n", + "metadata_map['celltype'] = all_celltypes\n", + "\n", + "_split_dict = {\n", + " 'train': 0,\n", + " 'val-id': 1,\n", + " 'test': 2,\n", + " 'val-ood': 3\n", + "}\n", + "_split_names = {\n", + " 'train': 'Train',\n", + " 'val-id': 'Validation (ID)',\n", + " 'test': 'Test',\n", + " 'val-ood': 'Validation (OOD)',\n", + "}\n", + "_split_scheme = 'standard'" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.0467748641968\n", + "('chr1', 4.52302885055542)\n", + "('chr2', 8.645489931106567)\n", + "('chr3', 11.959153890609741)\n", + "('chr4', 15.15813684463501)\n", + "('chr5', 18.22238802909851)\n", + "('chr6', 21.19420099258423)\n", + "('chr7', 23.940655946731567)\n", + "('chr8', 26.415233850479126)\n", + "('chr9', 28.833614826202393)\n", + "('chr10', 31.08920383453369)\n", + "('chr11', 33.37020301818848)\n", + "('chr12', 35.98973989486694)\n", + "('chr13', 37.88540601730347)\n", + "('chr14', 39.68082284927368)\n", + "('chr15', 41.242313861846924)\n", + "('chr16', 42.74874496459961)\n", + "('chr17', 44.12280797958374)\n", + "('chr18', 45.46893382072449)\n", + "('chr19', 46.50577902793884)\n", + "('chr20', 47.59563183784485)\n", + "('chr21', 48.31779384613037)\n", + "('chr22', 49.17265295982361)\n", + "('chrX', 51.75806999206543)\n", + "('H1-hESC', 25.880441904067993)\n", + "('HCT116', 50.130937814712524)\n", + "('HeLa-S3', 75.29559993743896)\n", + "('HepG2', 102.25979495048523)\n", + "('K562', 128.43050694465637)\n", + "('A549', 154.80679488182068)\n", + "('GM12878', 182.0279529094696)\n" + ] + } + ], + "source": [ + "itime = time.time()\n", + "sequence_filename = os.path.join(data_dir, 'sequence.npz')\n", + "seq_arr = np.load(sequence_filename)\n", + "print(time.time() - itime)\n", + "\n", + "itime = time.time()\n", + "_seq_bp = {}\n", + "for chrom in seq_arr:\n", + " _seq_bp[chrom] = seq_arr[chrom]\n", + " print(chrom, time.time() - itime)\n", + "\n", + "itime = time.time()\n", + "_dnase_allcelltypes = {}\n", + "for ct in all_celltypes:\n", + " dnase_filename = os.path.join(data_dir, '{}_dnase.npz'.format(ct))\n", + " dnase_npz_file = np.load(dnase_filename)\n", + " _dnase_allcelltypes[ct] = {}\n", + " for chrom in _seq_bp:\n", + " _dnase_allcelltypes[ct][chrom] = dnase_npz_file[chrom]\n", + " print(ct, time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'all_df' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# len(_dnase_allcelltypes)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mall_df\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mNameError\u001b[0m: name 'all_df' is not defined" + ] + } + ], + "source": [ + "# len(_dnase_allcelltypes)\n", + "all_df" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "ename": "AttributeError", + "evalue": "'module' object has no attribute 'isin'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mtr_chrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr2'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr9'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr11'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mte_chrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr1'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr8'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr21'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mtraining_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtr_chrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mval_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mval_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mte_chrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mall_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtraining_df\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_df\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mAttributeError\u001b[0m: 'module' object has no attribute 'isin'" + ] + } + ], + "source": [ + "tr_chrs = ['chr2', 'chr9', 'chr11']\n", + "te_chrs = ['chr1', 'chr8', 'chr21']\n", + "training_df = train_chr[np.isin(train_chr['chr'], tr_chrs)]\n", + "val_df = val_chr[np.isin(val_chr['chr'], te_chrs)]\n", + "all_df = pd.concat([training_df, val_df])\n", + "\n", + "#filter_msk = all_df['start'] >= 0\n", + "filter_msk = all_df['start']%1000 == 0\n", + "all_df = all_df[filter_msk]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "itime = time.time()\n", + "pd_list = []\n", + "for ct in all_celltypes:\n", + " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", + " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", + " tc_chr['celltype'] = ct\n", + " pd_list.append(tc_chr)\n", + "metadata_df = pd.concat(pd_list)\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "itime = time.time()\n", + "y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", + "non_ambig_mask = (y_array != -1)\n", + "metadata_df['y'] = y_array\n", + "_metadata_df = metadata_df[non_ambig_mask]\n", + "_y_array = torch.LongTensor(y_array[non_ambig_mask])\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "itime = time.time()\n", + "chr_ints = _metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['chr'])] )).values\n", + "celltype_ints = _metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['celltype'])] )).values\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_chr_mask = np.isin(_metadata_df['chr'], tr_chrs)\n", + "val_chr_mask = np.isin(_metadata_df['chr'], te_chrs)\n", + "train_celltype_mask = np.isin(_metadata_df['celltype'], train_celltypes)\n", + "val_celltype_mask = np.isin(_metadata_df['celltype'], val_celltype)\n", + "test_celltype_mask = np.isin(_metadata_df['celltype'], test_celltype)\n", + "\n", + "split_array = -1*np.ones(_metadata_df.shape[0]).astype(int)\n", + "split_array[np.logical_and(train_chr_mask, train_celltype_mask)] = _split_dict['train']\n", + "split_array[np.logical_and(val_chr_mask, test_celltype_mask)] = _split_dict['test']\n", + "split_array[np.logical_and(val_chr_mask, val_celltype_mask)] = _split_dict['val-ood']\n", + "split_array[np.logical_and(val_chr_mask, train_celltype_mask)] = _split_dict['val-id']\n", + "_metadata_df['split'] = split_array\n", + "_split_array = split_array" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "ename": "ImportError", + "evalue": "No module named data", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdataset_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mImportError\u001b[0m: No module named data" + ] + } + ], + "source": [ + "from torch.utils.data import DataLoader\n", + "from data import dataset_attributes" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "from PIL import Image\n", + "import argparse\n", + "class ParseKwargs(argparse.Action):\n", + " def __call__(self, parser, namespace, values, option_string=None):\n", + " setattr(namespace, self.dest, dict())\n", + " for value in values:\n", + " key, value_str = value.split('=')\n", + " if value_str.replace('-','').isnumeric():\n", + " processed_val = int(value_str)\n", + " elif value_str.replace('-','').replace('.','').isnumeric():\n", + " processed_val = float(value_str)\n", + " elif value_str in ['True', 'true']:\n", + " processed_val = True\n", + " elif value_str in ['False', 'false']:\n", + " processed_val = False\n", + " else:\n", + " processed_val = value_str\n", + " getattr(namespace, self.dest)[key] = processed_val" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'algorithm_constructors' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0;31m# Algorithm and objective\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 33\u001b[0;31m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'--algorithm'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrequired\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mchoices\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0malgorithm_constructors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 34\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'--algorithm_kwargs'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'*'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maction\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mParseKwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'--groupby_fields'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'+'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name 'algorithm_constructors' is not defined" + ] + } + ], + "source": [ + "ROOTDIR = '/oak/stanford/groups/akundaje/abalsubr/wilds_other'\n", + "args_kw = \"-d camelyon17 --algorithm ERM --model densenet121 --split_scheme standard --groupby_fields hospital --loss_function cross_entropy --optimizer SGD --lr 0.0001 --batch_size 32 --weight_decay 0 --n_epochs 10 --scheduler ReduceLROnPlateau --scheduler_metric_split val --scheduler_metric_name acc_avg --log_dir log --log_every 50 --save_step 1000 --save_best --save_last --seed 0 --evaluate_all_splits --root_dir {}\".format(\n", + " ROOTDIR).split()\n", + "\n", + "parser = argparse.ArgumentParser()\n", + "\n", + "# Dataset\n", + "parser.add_argument('-d', '--dataset', choices=['encodeTFBS', 'amazon', 'camelyon17', 'celebA', 'civilcomments', 'iwildcam', 'waterbirds', 'yelp', 'poverty', 'fmow', 'ogbg-molpcba'], required=True)\n", + "parser.add_argument('--split_scheme', default='standard',\n", + " help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')\n", + "parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--root_dir', default=None, required=True,\n", + " help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')\n", + "parser.add_argument('--download', default=False, action='store_true',\n", + " help='If true, tries to downloads the dataset if it does not exist in root_dir.')\n", + "parser.add_argument('--frac', type=float, default=1.0,\n", + " help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.')\n", + "\n", + "# Loaders\n", + "parser.add_argument('--train_loader', choices=['standard', 'group'], default='standard')\n", + "parser.add_argument('--train_loader_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--eval_loader', choices=['standard', 'group'], default='standard')\n", + "parser.add_argument('--batch_size', type=int, default=32)\n", + "parser.add_argument('--no_pin_memory', action='store_true') # TODO: put as loader_kwargs\n", + "parser.add_argument('--num_workers', type=int, default=4) # TODO: put as loader kwargs\n", + "\n", + "# Model\n", + "parser.add_argument(\n", + " '--model',\n", + " choices=['bert-base-uncased', 'inception_v3', 'densenet121', 'wideresnet50', 'resnet50', 'gin-virtual', 'resnet18_ms'],\n", + " default='resnet50')\n", + "parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},\n", + " help='keyword arguments for model initialization passed as key1=value1 key2=value2')\n", + "parser.add_argument('--train_from_scratch', action='store_true', default=False)\n", + "\n", + "# Algorithm and objective\n", + "parser.add_argument('--algorithm', required=True, choices=algorithm_constructors.keys())\n", + "parser.add_argument('--algorithm_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--groupby_fields', nargs='+', default=None)\n", + "parser.add_argument('--loss_function', required=True, choices = losses.keys()) #TODO: make default\n", + "parser.add_argument('--val_metric', default=None)\n", + "\n", + "# Optimization\n", + "parser.add_argument('--n_epochs', type=int, default=4)\n", + "parser.add_argument('--optimizer', default=None, choices=optimizer_attributes.keys())\n", + "parser.add_argument('--lr', type=float, required=True)\n", + "parser.add_argument('--weight_decay', type=float, required=True)\n", + "parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--scheduler', default=None, choices=scheduler_attributes.keys())\n", + "parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')\n", + "parser.add_argument('--scheduler_metric_name')\n", + "\n", + "# Evaluation\n", + "parser.add_argument('--evaluate_all_splits', action='store_true', default=False)\n", + "parser.add_argument('--additional_eval_splits', nargs='+', default=[])\n", + "\n", + "# Misc\n", + "parser.add_argument('--device', type=int, default=0)\n", + "parser.add_argument('--seed', type=int, default=0)\n", + "parser.add_argument('--log_dir', default='./logs')\n", + "parser.add_argument('--log_every', default=50, type=int)\n", + "parser.add_argument('--save_step', type=int, default=None)\n", + "parser.add_argument('--save_best', action='store_true', default=False)\n", + "parser.add_argument('--save_last', action='store_true', default=False)\n", + "parser.add_argument('--save_outputs', action='store_true', default=False)\n", + "parser.add_argument('--no_group_logging', action='store_true', default=False)\n", + "parser.add_argument('--val_metric_decreasing', action='store_true', default=False)\n", + "parser.add_argument('--use_wandb', action='store_true', default=False)\n", + "parser.add_argument('--progress_bar', action='store_true', default=False)\n", + "parser.add_argument('--resume', default=False, action='store_true')\n", + "parser.add_argument('--eval_only', default=False, action='store_true')\n", + "\n", + "args = parser.parse_args(args_kw)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# get_input (idx)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name '_metadata_df' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0midx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mthis_metadata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_metadata_df\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miloc\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mitime\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mflank_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m400\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name '_metadata_df' is not defined" + ] + } + ], + "source": [ + "idx = 3\n", + "this_metadata = _metadata_df.iloc[idx, :]\n", + "\n", + "itime = time.time()\n", + "flank_size = 400\n", + "interval_start = this_metadata['start'] - flank_size\n", + "interval_end = this_metadata['stop'] + flank_size\n", + "dnase_this = _dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end]\n", + "seq_this = _seq_bp[this_metadata['chr']][interval_start:interval_end]\n", + "data = np.column_stack([seq_this, dnase_this])\n", + "# print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.028102874755859375\n" + ] + } + ], + "source": [ + "itime = time.time()\n", + "metadata_array = torch.stack(\n", + " (torch.LongTensor(chr_ints), \n", + " torch.LongTensor(celltype_ints), \n", + " _y_array),\n", + " dim=1)\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'torch_scatter'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m#data.shape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata_loaders\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_train_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mget_eval_loader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/dr_benchmark/wilds/common/data_loaders.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msampler\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mWeightedRandomSampler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSubsetRandomSampler\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_counts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msplit_into_groups\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/dr_benchmark/wilds/common/utils.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch_scatter\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mSubset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mpandas\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtypes\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCategoricalDtype\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torch_scatter'" + ] + } + ], + "source": [ + "#data.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 157, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "4600" + ] + }, + "execution_count": 157, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data.shape\n", + "interval_end\n", + "# itime = time.time()\n", + "# np.save(os.path.join(data_dir, 'stmp.npy'), sa)\n", + "# print(time.time() - itime)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Run training experiment" + ] + }, + { + "cell_type": "code", + "execution_count": 167, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'python3 examples/run_expt.py -d encodeTFBS --algorithm ERM --model densenet121 --split_scheme standard --groupby_fields hospital --loss_function cross_entropy --optimizer SGD --lr 0.0001 --batch_size 32 --weight_decay 0 --n_epochs 10 --scheduler ReduceLROnPlateau --scheduler_metric_split val --scheduler_metric_name acc_avg --log_dir log --log_every 50 --save_step 1000 --save_best --save_last --seed 0 --evaluate_all_splits --root_dir ROOTDIR'" + ] + }, + "execution_count": 167, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cmdstr = \"python3 examples/run_expt.py -d encodeTFBS --algorithm ERM --model densenet121 --split_scheme standard --groupby_fields hospital --loss_function cross_entropy\"\n", + "cmdstr += \" \"\n", + "cmdstr += \"--optimizer SGD --lr 0.0001 --batch_size 32 --weight_decay 0 --n_epochs 10 --scheduler ReduceLROnPlateau --scheduler_metric_split val --scheduler_metric_name acc_avg\"\n", + "cmdstr += \" \"\n", + "cmdstr += \"--log_dir log --log_every 50 --save_step 1000 --save_best --save_last --seed 0 --evaluate_all_splits --root_dir ROOTDIR\"\n", + "cmdstr" + ] + }, + { + "cell_type": "code", + "execution_count": 164, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name '_metadata_array' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0m_metadata_array\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mNameError\u001b[0m: name '_metadata_array' is not defined" + ] + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 165, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'torch_scatter'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minsert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'..'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata_loaders\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_train_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mget_eval_loader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 17\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrouper\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCombinatorialGrouper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_counts\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/dr_benchmark/wilds/common/data_loaders.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msampler\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mWeightedRandomSampler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSubsetRandomSampler\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_counts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msplit_into_groups\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/dr_benchmark/wilds/common/utils.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch_scatter\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mSubset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mpandas\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtypes\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCategoricalDtype\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torch_scatter'" + ] + } + ], + "source": [ + "import os, csv\n", + "import time\n", + "import argparse\n", + "import IPython\n", + "import pandas as pd\n", + "import torch\n", + "import torch.nn as nn\n", + "import torchvision\n", + "import sys\n", + "from collections import defaultdict\n", + "# torch.multiprocessing.set_sharing_strategy('file_system')\n", + "\n", + "# TODO: Replace this once we make wilds into an installed package\n", + "sys.path.insert(1, os.path.join(sys.path[0], '..'))\n", + "\n", + "from wilds.common.data_loaders import get_train_loader, get_eval_loader\n", + "from wilds.common.grouper import CombinatorialGrouper\n", + "from wilds.common.utils import get_counts\n", + "\n", + "from models.model_attributes import model_attributes\n", + "from utils import set_seed, Logger, BatchLogger, log_args, ParseKwargs, load\n", + "from train import train, evaluate\n", + "from data import dataset_attributes\n", + "from optimizer import optimizer_attributes\n", + "from scheduler import scheduler_attributes\n", + "from loss import losses\n", + "from utils import log_group_data\n", + "from algorithms.constructors import algorithm_constructors" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from examples.models.model_attributes import model_attributes" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'utils'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_attributes\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmodel_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mset_seed\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mLogger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mCSVBatchLogger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mParseKwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mload\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 22\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdataset_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0moptimizer_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/dr_benchmark/examples/train.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0msave\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'utils'" + ] + } + ], + "source": [ + "def initialize_algorithm(args, datasets, train_grouper):\n", + " train_dataset = datasets['train']['dataset']\n", + " train_loader = datasets['train']['loader']\n", + "\n", + " # Configure the final layer of the networks used\n", + " # The code below are defaults. Edit this if you need special config for your model.\n", + " if (train_dataset.is_classification) and (train_dataset.y_size == 1):\n", + " # For single-task classification, we have one output per class\n", + " d_out = train_dataset.n_classes\n", + " elif (train_dataset.is_classification) and (train_dataset.y_size > 1) and (train_dataset.n_classes == 2):\n", + " # For multi-task binary classification (each output is the logit for each binary class)\n", + " d_out = train_dataset.y_size\n", + " elif (not train_dataset.is_classification):\n", + " # For regression, we have one output per target dimension\n", + " d_out = train_dataset.y_size\n", + " else:\n", + " raise RuntimeError('d_out not defined.')\n", + " \n", + "\n", + " # Sanity checking input args\n", + " if args.algorithm == 'groupDRO':\n", + " assert args.train_loader_kwargs['uniform_over_groups']\n", + " elif args.algorithm in ['deepCORAL', 'IRM']:\n", + " assert args.train_loader == 'group'\n", + " assert args.train_loader_kwargs['uniform_over_groups']\n", + " assert args.train_loader_kwargs['distinct_groups']\n", + "\n", + " # Other config\n", + " n_train_steps = len(train_loader) * args.n_epochs\n", + "# prediction_fn = dataset_attributes[args.dataset]['prediction_fn']\n", + " loss = losses[args.loss_function]\n", + " metric = dataset_attributes[args.dataset]['metric']\n", + " train_g = train_grouper.metadata_to_group(train_dataset.metadata_array)\n", + " is_group_in_train = get_counts(train_g, train_grouper.n_groups) > 0\n", + " algorithm_constructor = algorithm_constructors[args.algorithm]\n", + " algorithm = algorithm_constructor(\n", + " args=args,\n", + " d_out=d_out,\n", + " grouper=train_grouper,\n", + " loss=loss,\n", + " metric=metric,\n", + " n_train_steps=n_train_steps,\n", + " is_group_in_train=is_group_in_train)\n", + " return algorithm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def main():\n", + " parser = argparse.ArgumentParser()\n", + "\n", + " # Dataset\n", + " parser.add_argument('-d', '--dataset', choices=dataset_attributes.keys(), required=True)\n", + " parser.add_argument('--split_scheme', default='standard',\n", + " help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')\n", + " parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={})\n", + " parser.add_argument('--root_dir', default=None, required=True,\n", + " help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')\n", + " parser.add_argument('--download', default=False, action='store_true',\n", + " help='If true, tries to downloads the dataset if it does not exist in root_dir.')\n", + " parser.add_argument('--frac', type=float, default=1.0,\n", + " help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.')\n", + "\n", + " # Loaders\n", + " parser.add_argument('--train_loader', choices=['standard', 'group'], default='standard')\n", + " parser.add_argument('--train_loader_kwargs', nargs='*', action=ParseKwargs, default={})\n", + " parser.add_argument('--eval_loader', choices=['standard', 'group'], default='standard')\n", + " parser.add_argument('--batch_size', type=int, default=32)\n", + " parser.add_argument('--no_pin_memory', action='store_true') # TODO: put as loader_kwargs\n", + " parser.add_argument('--num_workers', type=int, default=4) # TODO: put as loader kwargs\n", + "\n", + " # Model\n", + " parser.add_argument(\n", + " '--model',\n", + " choices=model_attributes.keys(),\n", + " default='resnet50')\n", + " parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},\n", + " help='keyword arguments for model initialization passed as key1=value1 key2=value2')\n", + " parser.add_argument('--train_from_scratch', action='store_true', default=False)\n", + "\n", + " # Algorithm and objective\n", + " parser.add_argument('--algorithm', required=True, choices=algorithm_constructors.keys())\n", + " parser.add_argument('--algorithm_kwargs', nargs='*', action=ParseKwargs, default={})\n", + " parser.add_argument('--groupby_fields', nargs='+', default=None)\n", + " parser.add_argument('--loss_function', required=True, choices = losses.keys()) #TODO: make default\n", + " parser.add_argument('--val_metric', default=None)\n", + "\n", + " # Optimization\n", + " parser.add_argument('--n_epochs', type=int, default=4)\n", + " parser.add_argument('--optimizer', default=None, choices=optimizer_attributes.keys())\n", + " parser.add_argument('--lr', type=float, required=True)\n", + " parser.add_argument('--weight_decay', type=float, required=True)\n", + " parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})\n", + " parser.add_argument('--scheduler', default=None, choices=scheduler_attributes.keys())\n", + " parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})\n", + " parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')\n", + " parser.add_argument('--scheduler_metric_name')\n", + "\n", + " # Evaluation\n", + " parser.add_argument('--evaluate_all_splits', action='store_true', default=False)\n", + " parser.add_argument('--additional_eval_splits', nargs='+', default=[])\n", + "\n", + " # Misc\n", + " parser.add_argument('--device', type=int, default=0)\n", + " parser.add_argument('--seed', type=int, default=0)\n", + " parser.add_argument('--log_dir', default='./logs')\n", + " parser.add_argument('--log_every', default=50, type=int)\n", + " parser.add_argument('--save_step', type=int, default=None)\n", + " parser.add_argument('--save_best', action='store_true', default=False)\n", + " parser.add_argument('--save_last', action='store_true', default=False)\n", + " parser.add_argument('--save_outputs', action='store_true', default=False)\n", + " parser.add_argument('--no_group_logging', action='store_true', default=False)\n", + " parser.add_argument('--val_metric_decreasing', action='store_true', default=False)\n", + " parser.add_argument('--use_wandb', action='store_true', default=False)\n", + " parser.add_argument('--progress_bar', action='store_true', default=False)\n", + " parser.add_argument('--resume', default=False, action='store_true')\n", + " parser.add_argument('--eval_only', default=False, action='store_true')\n", + "\n", + " args = parser.parse_args()\n", + "\n", + " # set device\n", + " args.device = torch.device(\"cuda:\" + str(args.device)) if torch.cuda.is_available() else torch.device(\"cpu\")\n", + "\n", + " # Set defaults\n", + " if args.groupby_fields is None:\n", + " args.no_group_logging = True\n", + " if args.val_metric is None:\n", + " args.val_metric = dataset_attributes[args.dataset]['val_metric']\n", + "\n", + " ## Initialize logs\n", + " if os.path.exists(args.log_dir) and args.resume:\n", + " resume=True\n", + " mode='a'\n", + " else:\n", + " resume=False\n", + " mode='w'\n", + " if not os.path.exists(args.log_dir):\n", + " os.makedirs(args.log_dir)\n", + " logger = Logger(os.path.join(args.log_dir, 'log.txt'), mode)\n", + "\n", + " # Record args\n", + " log_args(args, logger)\n", + "\n", + " # Set random seed\n", + " set_seed(args.seed)\n", + "\n", + " # Data\n", + " full_dataset = dataset_attributes[args.dataset]['constructor'](\n", + " root_dir=args.root_dir,\n", + " download=args.download,\n", + " split_scheme=args.split_scheme,\n", + " **args.dataset_kwargs)\n", + "\n", + " # To implement data augmentation (i.e., have different transforms\n", + " # at training time vs. test time), modify these two lines:\n", + " train_transform = dataset_attributes[args.dataset]['transform'](args.model)\n", + " if dataset_attributes[args.dataset].get('eval_transform') is None:\n", + " eval_transform = dataset_attributes[args.dataset]['transform'](args.model)\n", + " else:\n", + " eval_transform = dataset_attributes[args.dataset]['eval_transform'](args.model)\n", + "\n", + " train_grouper = CombinatorialGrouper(\n", + " dataset=full_dataset,\n", + " groupby_fields=args.groupby_fields)\n", + "\n", + " datasets = defaultdict(dict)\n", + " for split in full_dataset.split_dict.keys():\n", + " if split=='train':\n", + " transform = train_transform\n", + " verbose = True\n", + " elif split == 'val':\n", + " transform = eval_transform\n", + " verbose = True\n", + " else:\n", + " transform = eval_transform\n", + " verbose = False\n", + " # Get subset\n", + " datasets[split]['dataset'] = full_dataset.get_subset(\n", + " split,\n", + " frac=args.frac,\n", + " transform=transform)\n", + "\n", + " # Get loader\n", + " shared_loader_kwargs = {\n", + " 'num_workers': args.num_workers,\n", + " 'pin_memory': not args.no_pin_memory,\n", + " 'batch_size': args.batch_size,\n", + " 'collate_fn': dataset_attributes[args.dataset]['collate']\n", + " }\n", + "\n", + " if split == 'train':\n", + " datasets[split]['loader'] = get_train_loader(\n", + " loader=args.train_loader,\n", + " dataset=datasets[split]['dataset'],\n", + " grouper=train_grouper,\n", + " train_loader_kwargs=args.train_loader_kwargs,\n", + " **shared_loader_kwargs)\n", + " else:\n", + " datasets[split]['loader'] = get_eval_loader(\n", + " loader=args.eval_loader,\n", + " dataset=datasets[split]['dataset'],\n", + " grouper=train_grouper,\n", + " **shared_loader_kwargs)\n", + "\n", + " # Set fields\n", + " datasets[split]['split'] = split\n", + " datasets[split]['name'] = full_dataset.split_names[split]\n", + " datasets[split]['verbose'] = verbose\n", + " # Loggers\n", + " # Loggers\n", + " datasets[split]['eval_logger'] = BatchLogger(\n", + " os.path.join(args.log_dir, f'{split}_eval.csv'), mode=mode, use_wandb=args.use_wandb)\n", + " datasets[split]['algo_logger'] = BatchLogger(\n", + " os.path.join(args.log_dir, f'{split}_algo.csv'), mode=mode, use_wandb=args.use_wandb)\n", + "\n", + " if args.use_wandb:\n", + " initialize_wandb(args)\n", + "\n", + " # Logging dataset info\n", + " if args.no_group_logging and full_dataset.is_classification and full_dataset.y_size==1:\n", + " log_grouper = CombinatorialGrouper(\n", + " dataset=full_dataset,\n", + " groupby_fields=['y'])\n", + " elif args.no_group_logging:\n", + " log_grouper = None\n", + " else:\n", + " log_grouper = train_grouper\n", + " log_group_data(args, datasets, log_grouper, logger)\n", + "\n", + " ## Initialize algorithm\n", + " algorithm = initialize_algorithm(args, datasets, train_grouper)\n", + "\n", + " if not args.eval_only:\n", + " ## Load saved results if resuming\n", + " resume_success = False\n", + " if resume:\n", + " save_path = os.path.join(args.log_dir, 'last_model.pth')\n", + " if not os.path.exists(save_path):\n", + " epochs = [\n", + " int(file.split('_')[0])\n", + " for file in os.listdir(args.log_dir) if file.endswith('.pth')]\n", + " if len(epochs) > 0:\n", + " latest_epoch = max(epochs)\n", + " save_path = os.path.join(args.log_dir, f'{latest_epoch}_model.pth')\n", + " try:\n", + " prev_epoch, best_val_metric = load(algorithm, save_path)\n", + " epoch_offset = prev_epoch + 1\n", + " logger.write(f'Resuming from epoch {epoch_offset} with best val metric {best_val_metric}')\n", + " resume_success = True\n", + " except FileNotFoundError:\n", + " pass\n", + "\n", + " if resume_success == False:\n", + " epoch_offset=0\n", + " best_val_metric=None\n", + "\n", + "\n", + " train(algorithm,\n", + " datasets,\n", + " logger,\n", + " args,\n", + " epoch_offset=epoch_offset,\n", + " best_val_metric=best_val_metric)\n", + " else:\n", + " best_model_path = os.path.join(args.log_dir, 'best_model.pth')\n", + " best_epoch, best_val_metric = load(algorithm, best_model_path)\n", + " evaluate(algorithm, datasets, best_epoch, logger)\n", + "\n", + " logger.close()\n", + " for split in datasets:\n", + " datasets[split]['eval_logger'].close()\n", + " datasets[split]['algo_logger'].close()\n", + "\n", + "if __name__=='__main__':\n", + " main()\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "language": "python", + "name": "python2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/sandbox_model.ipynb b/sandbox_model.ipynb new file mode 100644 index 00000000..c264d747 --- /dev/null +++ b/sandbox_model.ipynb @@ -0,0 +1,876 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Initialize dataset object" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "57.8772239685\n", + "66.8270189762\n" + ] + } + ], + "source": [ + "import numpy as np, pandas as pd, os, time, torch, torchvision\n", + "data_dir = '/oak/stanford/groups/akundaje/abalsubr/DREAM/wilds/codalab_archive/'\n", + "tf = 'MAX'\n", + "itime = time.time()\n", + "train_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.train.labels.tsv.gz'.format(tf)), sep='\\t')\n", + "print(time.time() - itime)\n", + "val_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.val.labels.tsv.gz'.format(tf)), sep='\\t')\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", + "val_celltype = ['A549']\n", + "test_celltype = ['GM12878']\n", + "all_celltypes = train_celltypes + val_celltype + test_celltype\n", + "\n", + "metadata_map = {}\n", + "metadata_map['chr'] = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']\n", + "metadata_map['celltype'] = all_celltypes\n", + "\n", + "_split_dict = {\n", + " 'train': 0,\n", + " 'val-id': 1,\n", + " 'test': 2,\n", + " 'val-ood': 3\n", + "}\n", + "_split_names = {\n", + " 'train': 'Train',\n", + " 'val-id': 'Validation (ID)',\n", + " 'test': 'Test',\n", + " 'val-ood': 'Validation (OOD)'\n", + "}\n", + "_split_scheme = 'standard'" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "('H1-hESC', 25.299736976623535)\n", + "('HCT116', 49.68733310699463)\n", + "('HeLa-S3', 74.65905213356018)\n", + "('HepG2', 99.33112812042236)\n", + "('K562', 124.1327919960022)\n", + "('A549', 149.19999814033508)\n", + "('GM12878', 174.0277030467987)\n" + ] + } + ], + "source": [ + "itime = time.time()\n", + "sequence_filename = os.path.join(data_dir, 'sequence.npz')\n", + "seq_arr = np.load(sequence_filename)\n", + "print(time.time() - itime)\n", + "\n", + "itime = time.time()\n", + "_seq_bp = {}\n", + "for chrom in seq_arr:\n", + " _seq_bp[chrom] = seq_arr[chrom]\n", + " print(chrom, time.time() - itime)\n", + "itime = time.time()\n", + "_dnase_allcelltypes = {}\n", + "for ct in all_celltypes:\n", + " dnase_filename = os.path.join(data_dir, '{}_dnase.npz'.format(ct))\n", + " dnase_npz_file = np.load(dnase_filename)\n", + " _dnase_allcelltypes[ct] = {}\n", + " for chrom in _seq_bp:\n", + " _dnase_allcelltypes[ct][chrom] = dnase_npz_file[chrom]\n", + " print(ct, time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'A549': {'chr1': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr10': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr11': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr12': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr13': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr14': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr15': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr16': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr17': array([ 0.35986328, 0.35986328, 0.35986328, ..., 0. ,\n", + " 0. , 0. ], dtype=float16),\n", + " 'chr18': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr19': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr2': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr20': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr21': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr22': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr3': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr4': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr5': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr6': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr7': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr8': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr9': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chrX': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16)},\n", + " 'GM12878': {'chr1': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr10': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr11': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr12': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr13': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr14': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr15': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr16': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr17': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr18': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr19': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr2': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr20': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr21': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr22': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr3': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr4': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr5': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr6': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr7': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr8': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr9': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chrX': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16)},\n", + " 'H1-hESC': {'chr1': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr10': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr11': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr12': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr13': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr14': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr15': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr16': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr17': array([ 0.71972656, 0.71972656, 0.71972656, ..., 0. ,\n", + " 0. , 0. ], dtype=float16),\n", + " 'chr18': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr19': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr2': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr20': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr21': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr22': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr3': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr4': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr5': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr6': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr7': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr8': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr9': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chrX': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16)},\n", + " 'HCT116': {'chr1': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr10': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr11': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr12': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr13': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr14': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr15': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr16': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr17': array([ 0.80419922, 0.80419922, 0.80419922, ..., 0. ,\n", + " 0. , 0. ], dtype=float16),\n", + " 'chr18': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr19': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr2': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr20': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr21': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr22': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr3': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr4': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr5': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr6': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr7': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr8': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr9': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chrX': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16)},\n", + " 'HeLa-S3': {'chr1': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr10': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr11': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr12': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr13': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr14': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr15': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr16': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr17': array([ 0.71972656, 0.71972656, 0.71972656, ..., 0. ,\n", + " 0. , 0. ], dtype=float16),\n", + " 'chr18': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr19': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr2': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr20': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr21': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr22': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr3': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr4': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr5': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr6': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr7': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr8': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr9': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chrX': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16)},\n", + " 'HepG2': {'chr1': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr10': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr11': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr12': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr13': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr14': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr15': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr16': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr17': array([ 0.71972656, 0.71972656, 0.71972656, ..., 0. ,\n", + " 0. , 0. ], dtype=float16),\n", + " 'chr18': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr19': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr2': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr20': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr21': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr22': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr3': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr4': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr5': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr6': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr7': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr8': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr9': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chrX': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16)},\n", + " 'K562': {'chr1': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr10': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr11': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr12': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr13': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr14': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr15': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr16': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr17': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr18': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr19': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr2': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr20': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr21': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr22': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr3': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr4': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr5': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr6': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr7': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr8': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chr9': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", + " 'chrX': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16)}}" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "_dnase_allcelltypes" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "tr_chrs = ['chr2', 'chr9', 'chr11']\n", + "te_chrs = ['chr1', 'chr8', 'chr21']\n", + "training_df = train_chr[np.isin(train_chr['chr'], tr_chrs)]\n", + "val_df = val_chr[np.isin(val_chr['chr'], te_chrs)]\n", + "all_df = pd.concat([training_df, val_df])\n", + "\n", + "#filter_msk = all_df['start'] >= 0\n", + "filter_msk = all_df['start']%1000 == 0\n", + "all_df = all_df[filter_msk]" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/users/abalsubr/anaconda2/envs/scs3/lib/python3.6/site-packages/ipykernel_launcher.py:6: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy\n", + " \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1.659163236618042\n" + ] + } + ], + "source": [ + "itime = time.time()\n", + "pd_list = []\n", + "for ct in all_celltypes:\n", + " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", + " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", + " tc_chr['celltype'] = ct\n", + " pd_list.append(tc_chr)\n", + "metadata_df = pd.concat(pd_list)\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3.0391879081726074\n" + ] + } + ], + "source": [ + "itime = time.time()\n", + "y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", + "non_ambig_mask = (y_array != -1)\n", + "metadata_df['y'] = y_array\n", + "_metadata_df = metadata_df[non_ambig_mask]\n", + "_y_array = torch.LongTensor(y_array[non_ambig_mask])\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "12.390011310577393\n" + ] + } + ], + "source": [ + "itime = time.time()\n", + "chr_ints = _metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['chr'])] )).values\n", + "celltype_ints = _metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['celltype'])] )).values\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/users/abalsubr/anaconda2/envs/scs3/lib/python3.6/site-packages/ipykernel_launcher.py:12: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy\n", + " if sys.path[0] == '':\n" + ] + } + ], + "source": [ + "train_chr_mask = np.isin(_metadata_df['chr'], tr_chrs)\n", + "val_chr_mask = np.isin(_metadata_df['chr'], te_chrs)\n", + "train_celltype_mask = np.isin(_metadata_df['celltype'], train_celltypes)\n", + "val_celltype_mask = np.isin(_metadata_df['celltype'], val_celltype)\n", + "test_celltype_mask = np.isin(_metadata_df['celltype'], test_celltype)\n", + "\n", + "split_array = -1*np.ones(_metadata_df.shape[0]).astype(int)\n", + "split_array[np.logical_and(train_chr_mask, train_celltype_mask)] = _split_dict['train']\n", + "split_array[np.logical_and(val_chr_mask, test_celltype_mask)] = _split_dict['test']\n", + "split_array[np.logical_and(val_chr_mask, val_celltype_mask)] = _split_dict['val-ood']\n", + "split_array[np.logical_and(val_chr_mask, train_celltype_mask)] = _split_dict['val-id']\n", + "_metadata_df['split'] = split_array\n", + "_split_array = split_array" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# get_input (idx)" + ] + }, + { + "cell_type": "code", + "execution_count": 153, + "metadata": {}, + "outputs": [], + "source": [ + "idx = 3\n", + "this_metadata = _metadata_df.iloc[idx, :]\n", + "\n", + "itime = time.time()\n", + "flank_size = 400\n", + "interval_start = this_metadata['start'] - flank_size\n", + "interval_end = this_metadata['stop'] + flank_size\n", + "dnase_this = _dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end]\n", + "seq_this = _seq_bp[this_metadata['chr']][interval_start:interval_end]\n", + "data = np.column_stack([seq_this, dnase_this])\n", + "# print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 154, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "4600" + ] + }, + "execution_count": 154, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data.shape\n", + "interval_end\n", + "# itime = time.time()\n", + "# np.save(os.path.join(data_dir, 'stmp.npy'), sa)\n", + "# print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mitime\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m metadata_array = torch.stack(\n\u001b[0;32m----> 3\u001b[0;31m (torch.LongTensor(metadata_df['chr'].values), \n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLongTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetadata_df\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'celltype'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m self._y_array),\n", + "\u001b[0;31mTypeError\u001b[0m: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool." + ] + } + ], + "source": [ + "itime = time.time()\n", + "metadata_array = torch.stack(\n", + " (torch.LongTensor(chr_ints), \n", + " torch.LongTensor(celltype_ints), \n", + " _y_array),\n", + " dim=1)\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "python3 examples/run_expt.py -d camelyon17 --algorithm ERM --model densenet121 --split_scheme standard --groupby_fields hospital --loss_function cross_entropy --optimizer SGD \n", + "--lr 0.0001 --batch_size 32 --weight_decay 0 --n_epochs 10 --scheduler ReduceLROnPlateau --scheduler_metric_split val --scheduler_metric_name acc_avg \n", + "--log_dir log --log_every 50 --save_step 1000 --save_best --save_last --seed 0 --evaluate_all_splits --root_dir ROOTDIR" + ] + }, + { + "cell_type": "code", + "execution_count": 156, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name '_metadata_array' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0m_metadata_array\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mNameError\u001b[0m: name '_metadata_array' is not defined" + ] + } + ], + "source": [ + "_metadata_array" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from examples.models.model_attributes import model_attributes" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'utils'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_attributes\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmodel_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mset_seed\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mLogger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mCSVBatchLogger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mParseKwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mload\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 22\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdataset_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0moptimizer_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/dr_benchmark/examples/train.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0msave\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'utils'" + ] + } + ], + "source": [ + "import os, csv\n", + "import time\n", + "import argparse\n", + "import IPython\n", + "import pandas as pd\n", + "import torch\n", + "import torch.nn as nn\n", + "import torchvision\n", + "import sys\n", + "from collections import defaultdict\n", + "\n", + "# TODO: Replace this once we make wilds into an installed package\n", + "sys.path.insert(1, os.path.join(sys.path[0], '..'))\n", + "\n", + "from wilds.common.data_loaders import get_train_loader, get_eval_loader\n", + "from wilds.common.grouper import CombinatorialGrouper\n", + "from wilds.common.utils import get_counts\n", + "\n", + "from examples.models.model_attributes import model_attributes\n", + "from examples.utils import set_seed, Logger, CSVBatchLogger, log_args, ParseKwargs, load\n", + "from examples.train import train\n", + "from examples.data import dataset_attributes\n", + "from examples.optimizer import optimizer_attributes\n", + "from examples.scheduler import scheduler_attributes\n", + "from examples.loss import losses\n", + "from examples.utils import log_group_data\n", + "from examples.algorithms.constructors import algorithm_constructors\n", + "\n", + "\n", + "def initialize_algorithm(args, datasets, train_grouper):\n", + " train_dataset = datasets['train']['dataset']\n", + " train_loader = datasets['train']['loader']\n", + "\n", + " # Configure the final layer of the networks used\n", + " # The code below are defaults. Edit this if you need special config for your model.\n", + " if (train_dataset.is_classification) and (train_dataset.y_size == 1):\n", + " # For single-task classification, we have one output per class\n", + " d_out = train_dataset.n_classes\n", + " elif (not train_dataset.is_classification):\n", + " # For regression, we have one output per target dimension\n", + " d_out = train_dataset.y_size\n", + " else:\n", + " # TODO: Handle dataset-specific multi-task stuff here, e.g., for OGB\n", + " pass\n", + "\n", + " # Sanity checking input args\n", + " if args.algorithm == 'groupDRO':\n", + " assert args.train_loader_kwargs['uniform_over_groups']\n", + " elif args.algorithm in ['deepCORAL', 'IRM']:\n", + " assert args.train_loader == 'group'\n", + " assert args.train_loader_kwargs['uniform_over_groups']\n", + " assert args.train_loader_kwargs['distinct_groups']\n", + "\n", + " # Other config\n", + " n_train_steps = len(train_loader) * args.n_epochs\n", + " prediction_fn = dataset_attributes[args.dataset]['prediction_fn']\n", + " loss = losses[args.loss_function]\n", + " metric_constructor = dataset_attributes[args.dataset]['metric']\n", + " train_g = train_grouper.metadata_to_group(train_dataset.metadata_array)\n", + " is_group_in_train = get_counts(train_g, train_grouper.n_groups) > 0\n", + " algorithm_constructor = algorithm_constructors[args.algorithm]\n", + " algorithm = algorithm_constructor(\n", + " args=args,\n", + " d_out=d_out,\n", + " grouper=train_grouper,\n", + " prediction_fn=prediction_fn,\n", + " loss=loss,\n", + " metric_constructor=metric_constructor,\n", + " n_train_steps=n_train_steps,\n", + " is_group_in_train=is_group_in_train)\n", + " return algorithm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "parser = argparse.ArgumentParser()\n", + "\n", + "# Dataset\n", + "parser.add_argument('-d', '--dataset', choices=dataset_attributes.keys(), required=True)\n", + "parser.add_argument('--split_scheme', default='standard',\n", + " help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')\n", + "parser.add_argument('--root_dir', default=None, required=True,\n", + " help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')\n", + "parser.add_argument('--download', default=False, action='store_true',\n", + " help='If true, tries to downloads the dataset if it does not exist in root_dir.')\n", + "parser.add_argument('--frac', type=float, default=1.0,\n", + " help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.')\n", + "\n", + "# Loaders\n", + "parser.add_argument('--train_loader', choices=['standard', 'group'], default='standard')\n", + "parser.add_argument('--train_loader_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--eval_loader', choices=['standard', 'group'], default='standard')\n", + "parser.add_argument('--batch_size', type=int, default=32)\n", + "\n", + "# Model\n", + "parser.add_argument(\n", + " '--model',\n", + " choices=model_attributes.keys(),\n", + " default='resnet50')\n", + "parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},\n", + " help='keyword arguments for model initialization passed as key1=value1 key2=value2')\n", + "parser.add_argument('--train_from_scratch', action='store_true', default=False)\n", + "\n", + "# Algorithm and objective\n", + "parser.add_argument('--algorithm', required=True, choices=algorithm_constructors.keys())\n", + "parser.add_argument('--algorithm_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--groupby_fields', nargs='+', default=None)\n", + "parser.add_argument('--loss_function', required=True, choices = losses.keys()) #TODO: make default\n", + "parser.add_argument('--val_metric', default=None)\n", + "\n", + "# Optimization\n", + "parser.add_argument('--n_epochs', type=int, default=4)\n", + "parser.add_argument('--optimizer', default=None, choices=optimizer_attributes.keys())\n", + "parser.add_argument('--lr', type=float, required=True)\n", + "parser.add_argument('--weight_decay', type=float, required=True)\n", + "parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--scheduler', default=None, choices=scheduler_attributes.keys())\n", + "parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')\n", + "parser.add_argument('--scheduler_metric_name')\n", + "\n", + "# Evaluation\n", + "parser.add_argument('--evaluate_all_splits', action='store_true', default=False)\n", + "parser.add_argument('--additional_eval_splits', nargs='+', default=[])\n", + "\n", + "# Misc\n", + "parser.add_argument('--device', default='cuda')\n", + "parser.add_argument('--seed', type=int, default=0)\n", + "parser.add_argument('--log_dir', default='./logs')\n", + "parser.add_argument('--log_every', default=50, type=int)\n", + "parser.add_argument('--save_step', type=int, default=None)\n", + "parser.add_argument('--save_best', action='store_true', default=False)\n", + "parser.add_argument('--save_last', action='store_true', default=False)\n", + "parser.add_argument('--save_outputs', action='store_true', default=False)\n", + "parser.add_argument('--no_group_logging', action='store_true', default=False)\n", + "\n", + "parser.add_argument('--resume', default=False, action='store_true')\n", + "\n", + "args = parser.parse_args()\n", + "\n", + "# Set defaults\n", + "if args.groupby_fields is None:\n", + " args.no_group_logging = True\n", + "if args.val_metric is None:\n", + " args.val_metric = dataset_attributes[args.dataset]['val_metric']\n", + "\n", + "## Initialize logs\n", + "if os.path.exists(args.log_dir) and args.resume:\n", + " resume=True\n", + " mode='a'\n", + "else:\n", + " resume=False\n", + " mode='w'\n", + "if not os.path.exists(args.log_dir):\n", + " os.makedirs(args.log_dir)\n", + "logger = Logger(os.path.join(args.log_dir, 'log.txt'), mode)\n", + "\n", + "# Record args\n", + "log_args(args, logger)\n", + "\n", + "# Set random seed\n", + "set_seed(args.seed)\n", + "\n", + "# Data\n", + "full_dataset = dataset_attributes[args.dataset]['constructor'](\n", + " root_dir=args.root_dir,\n", + " download=args.download,\n", + " split_scheme=args.split_scheme)\n", + "\n", + "# To implement data augmentation (i.e., have different transforms\n", + "# at training time vs. test time), modify these two lines:\n", + "train_transform = dataset_attributes[args.dataset]['transform'](args.model)\n", + "eval_transform = dataset_attributes[args.dataset]['transform'](args.model)\n", + "\n", + "train_grouper = CombinatorialGrouper(\n", + " dataset=full_dataset,\n", + " groupby_fields=args.groupby_fields)\n", + "\n", + "datasets = defaultdict(dict)\n", + "for split in full_dataset.split_dict.keys():\n", + " if split=='train':\n", + " transform = train_transform\n", + " verbose = True\n", + " elif split == 'val':\n", + " transform = eval_transform\n", + " verbose = True\n", + " else:\n", + " transform = eval_transform\n", + " verbose = False\n", + " # Get subset\n", + " datasets[split]['dataset'] = full_dataset.get_subset(\n", + " split,\n", + " frac=args.frac,\n", + " transform=transform)\n", + "\n", + " # Get loader\n", + " shared_loader_kwargs = {\n", + " 'num_workers': 4,\n", + " 'pin_memory': True,\n", + " 'batch_size': args.batch_size,\n", + " 'collate_fn': dataset_attributes[args.dataset]['collate']\n", + " }\n", + "\n", + " if split == 'train':\n", + " datasets[split]['loader'] = get_train_loader(\n", + " loader=args.train_loader,\n", + " dataset=datasets[split]['dataset'],\n", + " grouper=train_grouper,\n", + " train_loader_kwargs=args.train_loader_kwargs,\n", + " **shared_loader_kwargs)\n", + " else:\n", + " datasets[split]['loader'] = get_eval_loader(\n", + " loader=args.eval_loader,\n", + " dataset=datasets[split]['dataset'],\n", + " grouper=train_grouper,\n", + " **shared_loader_kwargs)\n", + "\n", + " # Set fields\n", + " datasets[split]['split'] = split\n", + " datasets[split]['name'] = full_dataset.split_names[split]\n", + " datasets[split]['verbose'] = verbose\n", + " # Loggers\n", + " datasets[split]['eval_logger'] = CSVBatchLogger(\n", + " os.path.join(args.log_dir, f'{split}_eval.csv'), mode=mode)\n", + " datasets[split]['algo_logger'] = CSVBatchLogger(\n", + " os.path.join(args.log_dir, f'{split}_algo.csv'), mode=mode)\n", + "\n", + "# Logging dataset info\n", + "if args.no_group_logging and full_dataset.is_classification and full_dataset.y_size==1:\n", + " log_grouper = CombinatorialGrouper(\n", + " dataset=full_dataset,\n", + " groupby_fields=['y'])\n", + "elif args.no_group_logging:\n", + " log_grouper = None\n", + "else:\n", + " log_grouper = train_grouper\n", + "log_group_data(args, datasets, log_grouper, logger)\n", + "\n", + "## Initialize algorithm\n", + "algorithm = initialize_algorithm(args, datasets, train_grouper)\n", + "\n", + "## Load saved results if resuming\n", + "if resume:\n", + " save_path = os.path.join(args.log_dir, 'last_model.pth')\n", + " prev_epoch, best_val_metric = load(algorithm, save_path)\n", + " epoch_offset = prev_epoch + 1\n", + "else:\n", + " epoch_offset=0\n", + " best_val_metric=None\n", + "\n", + "train(algorithm,\n", + " datasets,\n", + " logger,\n", + " args,\n", + " epoch_offset=epoch_offset,\n", + " best_val_metric=best_val_metric)\n", + "\n", + "logger.close()\n", + "for split in datasets:\n", + " datasets[split]['eval_logger'].close()\n", + " datasets[split]['algo_logger'].close()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "language": "python", + "name": "python2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From 4ca2160fd16cd9e7335afbf6747e2fca47e9bf3b Mon Sep 17 00:00:00 2001 From: aikanor Date: Tue, 2 Feb 2021 11:38:01 -0800 Subject: [PATCH 052/244] model tweak --- examples/models/CNN_genome.py | 151 +--------------------------------- sandbox_model.ipynb | 32 +++++++ 2 files changed, 34 insertions(+), 149 deletions(-) diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py index f0115322..75295cd3 100644 --- a/examples/models/CNN_genome.py +++ b/examples/models/CNN_genome.py @@ -16,7 +16,7 @@ class Beagle(nn.Module): TODO: Finish docstring. """ - def __init__(self, args): + def __init__(self): """ Parameters ---------- @@ -25,7 +25,7 @@ def __init__(self, args): """ super(Beagle, self).__init__() - self.dropout = args.dropout + self.dropout = 0.3 self.num_cell_types = 1 self.conv1 = nn.Conv2d(5, 300, (19, 1), stride = (1, 1), padding=(9,0)) self.conv2 = nn.Conv2d(300, 200, (11, 1), stride = (1, 1), padding = (5,0)) @@ -60,150 +60,3 @@ def forward(self, s): s = self.fc3(s) return s, conv_out - - - -#class MLP(nn.Module): -# """Just an MLP""" -# def __init__(self, n_inputs, n_outputs, width, depth, drop_out): -# super(MLP, self).__init__() -# -# self.input = nn.Linear(n_inputs, width) -# self.dropout = nn.Dropout(dropout) -# self.hiddens = nn.ModuleList([ -# nn.Linear(width,width) -# for _ in range(depth-2)]) -# self.output = nn.Linear(width, n_outputs) -# self.n_outputs = n_outputs -# -# def forward(self, x): -# x = self.input(x) -# x = self.dropout(x) -# x = F.relu(x) -# for hidden in self.hiddens: -# x = hidden(x) -# x = self.dropout(x) -# x = F.relu(x) -# x = self.output(x) -# return x - - -""" -DeepSEA architecture (Zhou & Troyanskaya, 2015). -Based on https://github.com/FunctionLab/selene/blob/master/models/deepsea.py -""" - -class DeepSEA(nn.Module): - def __init__(self, sequence_length, n_genomic_features): - """ - Parameters - ---------- - sequence_length : int - n_genomic_features : int - """ - super(DeepSEA, self).__init__() - conv_kernel_size = 8 - pool_kernel_size = 4 - - self.conv_net = nn.Sequential( - nn.Conv1d(4, 320, kernel_size=conv_kernel_size), - nn.ReLU(inplace=True), - nn.MaxPool1d( - kernel_size=pool_kernel_size, stride=pool_kernel_size), - nn.Dropout(p=0.2), - - nn.Conv1d(320, 480, kernel_size=conv_kernel_size), - nn.ReLU(inplace=True), - nn.MaxPool1d( - kernel_size=pool_kernel_size, stride=pool_kernel_size), - nn.Dropout(p=0.2), - - nn.Conv1d(480, 960, kernel_size=conv_kernel_size), - nn.ReLU(inplace=True), - nn.Dropout(p=0.5)) - - reduce_by = conv_kernel_size - 1 - pool_kernel_size = float(pool_kernel_size) - self.n_channels = int( - np.floor( - (np.floor( - (sequence_length - reduce_by) / pool_kernel_size) - - reduce_by) / pool_kernel_size) - - reduce_by) - self.classifier = nn.Sequential( - nn.Linear(960 * self.n_channels, n_genomic_features), - nn.ReLU(inplace=True), - nn.Linear(n_genomic_features, n_genomic_features), - nn.Sigmoid()) - - def forward(self, x): - """Forward propagation of a batch. - """ - out = self.conv_net(x) - reshape_out = out.view(out.size(0), 960 * self.n_channels) - predict = self.classifier(reshape_out) - return predict - -""" -def criterion(): - return nn.BCELoss() - -def get_optimizer(lr): - # The optimizer and the parameters with which to initialize the optimizer. At a later time, we initialize the optimizer by also passing in the model parameters (`model.parameters()`). We cannot initialize the optimizer until the model has been initialized. - return (torch.optim.SGD, {"lr": lr, "weight_decay": 1e-6, "momentum": 0.9}) -""" - - - -""" -DanQ architecture (Quang & Xie, 2016). -""" - -class DanQ(nn.Module): - def __init__(self, sequence_length, n_genomic_features): - """ - Parameters - ---------- - sequence_length : int - Input sequence length - n_genomic_features : int - Total number of features to predict - """ - super(DanQ, self).__init__() - self.nnet = nn.Sequential( - nn.Conv1d(4, 320, kernel_size=26), - nn.ReLU(inplace=True), - nn.MaxPool1d( - kernel_size=13, stride=13), - nn.Dropout(0.2)) - - self.bdlstm = nn.Sequential(nn.LSTM(320, 320, num_layers=1, batch_first=True, bidirectional=True)) - - self._n_channels = math.floor( - (sequence_length - 25) / 13) - self.classifier = nn.Sequential( - nn.Dropout(0.5), - nn.Linear(self._n_channels * 640, 925), - nn.ReLU(inplace=True), - nn.Linear(925, n_genomic_features), - nn.Sigmoid()) - - def forward(self, x): - """Forward propagation of a batch. - """ - out = self.nnet(x) - reshape_out = out.transpose(0, 1).transpose(0, 2) - out, _ = self.bdlstm(reshape_out) - out = out.transpose(0, 1) - reshape_out = out.contiguous().view( - out.size(0), 640 * self._n_channels) - predict = self.classifier(reshape_out) - return predict - -""" -def criterion(): - return nn.BCELoss() - -def get_optimizer(lr): - return (torch.optim.RMSprop, {"lr": lr}) -""" \ No newline at end of file diff --git a/sandbox_model.ipynb b/sandbox_model.ipynb index c264d747..885c8a59 100644 --- a/sandbox_model.ipynb +++ b/sandbox_model.ipynb @@ -288,6 +288,38 @@ "_dnase_allcelltypes" ] }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "from examples.models import CNN_genome" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "unbound method parameters() must be called with Beagle instance as first argument (got nothing instead)", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# def count_parameters(model):\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;31m# return sum(p.numel() for p in model.parameters() if p.requires_grad)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mCNN_genome\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mBeagle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m: unbound method parameters() must be called with Beagle instance as first argument (got nothing instead)" + ] + } + ], + "source": [ + "# def count_parameters(model):\n", + "# return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", + "CNN_genome.Beagle.parameters()" + ] + }, { "cell_type": "code", "execution_count": 14, From f3d344f5953108214ef8d9b6d97486e50bc13241 Mon Sep 17 00:00:00 2001 From: aikanor Date: Sun, 7 Feb 2021 16:24:31 -0800 Subject: [PATCH 053/244] preprocessing changes --- .../sandbox_data-checkpoint.ipynb | 952 ------------------ dataset_preprocessing/encode-tfbs/README.md | 18 + .../encode-tfbs/prep_accessibility.py | 41 + .../encode-tfbs/prep_datasets.ipynb | 279 +++++ .../encode-tfbs/prep_sequence.py | 151 +++ .../sandbox_data.ipynb => sandbox_data.ipynb | 77 +- sandbox_model.ipynb | 486 +++++---- 7 files changed, 812 insertions(+), 1192 deletions(-) delete mode 100644 dataset_preprocessing/encode-tfbs/.ipynb_checkpoints/sandbox_data-checkpoint.ipynb create mode 100644 dataset_preprocessing/encode-tfbs/README.md create mode 100644 dataset_preprocessing/encode-tfbs/prep_accessibility.py create mode 100644 dataset_preprocessing/encode-tfbs/prep_datasets.ipynb create mode 100644 dataset_preprocessing/encode-tfbs/prep_sequence.py rename dataset_preprocessing/encode-tfbs/sandbox_data.ipynb => sandbox_data.ipynb (97%) diff --git a/dataset_preprocessing/encode-tfbs/.ipynb_checkpoints/sandbox_data-checkpoint.ipynb b/dataset_preprocessing/encode-tfbs/.ipynb_checkpoints/sandbox_data-checkpoint.ipynb deleted file mode 100644 index b2e74829..00000000 --- a/dataset_preprocessing/encode-tfbs/.ipynb_checkpoints/sandbox_data-checkpoint.ipynb +++ /dev/null @@ -1,952 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Initialize dataset object" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "57.5289368629\n", - "65.2459537983\n" - ] - } - ], - "source": [ - "import numpy as np, pandas as pd, os, time\n", - "import torch, torchvision\n", - "\n", - "data_dir = '/oak/stanford/groups/akundaje/abalsubr/DREAM/wilds/codalab_archive/'\n", - "tf = 'MAX'\n", - "itime = time.time()\n", - "train_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.train.labels.tsv.gz'.format(tf)), sep='\\t')\n", - "print(time.time() - itime)\n", - "val_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.val.labels.tsv.gz'.format(tf)), sep='\\t')\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", - "val_celltype = ['A549']\n", - "test_celltype = ['GM12878']\n", - "all_celltypes = train_celltypes + val_celltype + test_celltype\n", - "\n", - "metadata_map = {}\n", - "metadata_map['chr'] = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']\n", - "metadata_map['celltype'] = all_celltypes\n", - "\n", - "_split_dict = {\n", - " 'train': 0,\n", - " 'val-id': 1,\n", - " 'test': 2,\n", - " 'val-ood': 3\n", - "}\n", - "_split_names = {\n", - " 'train': 'Train',\n", - " 'val-id': 'Validation (ID)',\n", - " 'test': 'Test',\n", - " 'val-ood': 'Validation (OOD)',\n", - "}\n", - "_split_scheme = 'standard'" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.0467748641968\n", - "('chr1', 4.52302885055542)\n", - "('chr2', 8.645489931106567)\n", - "('chr3', 11.959153890609741)\n", - "('chr4', 15.15813684463501)\n", - "('chr5', 18.22238802909851)\n", - "('chr6', 21.19420099258423)\n", - "('chr7', 23.940655946731567)\n", - "('chr8', 26.415233850479126)\n", - "('chr9', 28.833614826202393)\n", - "('chr10', 31.08920383453369)\n", - "('chr11', 33.37020301818848)\n", - "('chr12', 35.98973989486694)\n", - "('chr13', 37.88540601730347)\n", - "('chr14', 39.68082284927368)\n", - "('chr15', 41.242313861846924)\n", - "('chr16', 42.74874496459961)\n", - "('chr17', 44.12280797958374)\n", - "('chr18', 45.46893382072449)\n", - "('chr19', 46.50577902793884)\n", - "('chr20', 47.59563183784485)\n", - "('chr21', 48.31779384613037)\n", - "('chr22', 49.17265295982361)\n", - "('chrX', 51.75806999206543)\n", - "('H1-hESC', 25.880441904067993)\n", - "('HCT116', 50.130937814712524)\n", - "('HeLa-S3', 75.29559993743896)\n", - "('HepG2', 102.25979495048523)\n", - "('K562', 128.43050694465637)\n", - "('A549', 154.80679488182068)\n", - "('GM12878', 182.0279529094696)\n" - ] - } - ], - "source": [ - "itime = time.time()\n", - "sequence_filename = os.path.join(data_dir, 'sequence.npz')\n", - "seq_arr = np.load(sequence_filename)\n", - "print(time.time() - itime)\n", - "\n", - "itime = time.time()\n", - "_seq_bp = {}\n", - "for chrom in seq_arr:\n", - " _seq_bp[chrom] = seq_arr[chrom]\n", - " print(chrom, time.time() - itime)\n", - "\n", - "itime = time.time()\n", - "_dnase_allcelltypes = {}\n", - "for ct in all_celltypes:\n", - " dnase_filename = os.path.join(data_dir, '{}_dnase.npz'.format(ct))\n", - " dnase_npz_file = np.load(dnase_filename)\n", - " _dnase_allcelltypes[ct] = {}\n", - " for chrom in _seq_bp:\n", - " _dnase_allcelltypes[ct][chrom] = dnase_npz_file[chrom]\n", - " print(ct, time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'all_df' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# len(_dnase_allcelltypes)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mall_df\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mNameError\u001b[0m: name 'all_df' is not defined" - ] - } - ], - "source": [ - "# len(_dnase_allcelltypes)\n", - "all_df" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "ename": "AttributeError", - "evalue": "'module' object has no attribute 'isin'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mtr_chrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr2'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr9'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr11'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mte_chrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr1'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr8'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr21'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mtraining_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtr_chrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mval_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mval_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mte_chrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mall_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtraining_df\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_df\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mAttributeError\u001b[0m: 'module' object has no attribute 'isin'" - ] - } - ], - "source": [ - "tr_chrs = ['chr2', 'chr9', 'chr11']\n", - "te_chrs = ['chr1', 'chr8', 'chr21']\n", - "training_df = train_chr[np.isin(train_chr['chr'], tr_chrs)]\n", - "val_df = val_chr[np.isin(val_chr['chr'], te_chrs)]\n", - "all_df = pd.concat([training_df, val_df])\n", - "\n", - "#filter_msk = all_df['start'] >= 0\n", - "filter_msk = all_df['start']%1000 == 0\n", - "all_df = all_df[filter_msk]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "itime = time.time()\n", - "pd_list = []\n", - "for ct in all_celltypes:\n", - " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", - " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", - " tc_chr['celltype'] = ct\n", - " pd_list.append(tc_chr)\n", - "metadata_df = pd.concat(pd_list)\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "itime = time.time()\n", - "y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", - "non_ambig_mask = (y_array != -1)\n", - "metadata_df['y'] = y_array\n", - "_metadata_df = metadata_df[non_ambig_mask]\n", - "_y_array = torch.LongTensor(y_array[non_ambig_mask])\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "itime = time.time()\n", - "chr_ints = _metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['chr'])] )).values\n", - "celltype_ints = _metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['celltype'])] )).values\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "train_chr_mask = np.isin(_metadata_df['chr'], tr_chrs)\n", - "val_chr_mask = np.isin(_metadata_df['chr'], te_chrs)\n", - "train_celltype_mask = np.isin(_metadata_df['celltype'], train_celltypes)\n", - "val_celltype_mask = np.isin(_metadata_df['celltype'], val_celltype)\n", - "test_celltype_mask = np.isin(_metadata_df['celltype'], test_celltype)\n", - "\n", - "split_array = -1*np.ones(_metadata_df.shape[0]).astype(int)\n", - "split_array[np.logical_and(train_chr_mask, train_celltype_mask)] = _split_dict['train']\n", - "split_array[np.logical_and(val_chr_mask, test_celltype_mask)] = _split_dict['test']\n", - "split_array[np.logical_and(val_chr_mask, val_celltype_mask)] = _split_dict['val-ood']\n", - "split_array[np.logical_and(val_chr_mask, train_celltype_mask)] = _split_dict['val-id']\n", - "_metadata_df['split'] = split_array\n", - "_split_array = split_array" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "ename": "ImportError", - "evalue": "No module named data", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdataset_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mImportError\u001b[0m: No module named data" - ] - } - ], - "source": [ - "from torch.utils.data import DataLoader\n", - "from data import dataset_attributes" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "from PIL import Image\n", - "import argparse\n", - "class ParseKwargs(argparse.Action):\n", - " def __call__(self, parser, namespace, values, option_string=None):\n", - " setattr(namespace, self.dest, dict())\n", - " for value in values:\n", - " key, value_str = value.split('=')\n", - " if value_str.replace('-','').isnumeric():\n", - " processed_val = int(value_str)\n", - " elif value_str.replace('-','').replace('.','').isnumeric():\n", - " processed_val = float(value_str)\n", - " elif value_str in ['True', 'true']:\n", - " processed_val = True\n", - " elif value_str in ['False', 'false']:\n", - " processed_val = False\n", - " else:\n", - " processed_val = value_str\n", - " getattr(namespace, self.dest)[key] = processed_val" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'algorithm_constructors' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0;31m# Algorithm and objective\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 33\u001b[0;31m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'--algorithm'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrequired\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mchoices\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0malgorithm_constructors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 34\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'--algorithm_kwargs'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'*'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maction\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mParseKwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'--groupby_fields'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'+'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mNameError\u001b[0m: name 'algorithm_constructors' is not defined" - ] - } - ], - "source": [ - "ROOTDIR = '/oak/stanford/groups/akundaje/abalsubr/wilds_other'\n", - "args_kw = \"-d camelyon17 --algorithm ERM --model densenet121 --split_scheme standard --groupby_fields hospital --loss_function cross_entropy --optimizer SGD --lr 0.0001 --batch_size 32 --weight_decay 0 --n_epochs 10 --scheduler ReduceLROnPlateau --scheduler_metric_split val --scheduler_metric_name acc_avg --log_dir log --log_every 50 --save_step 1000 --save_best --save_last --seed 0 --evaluate_all_splits --root_dir {}\".format(\n", - " ROOTDIR).split()\n", - "\n", - "parser = argparse.ArgumentParser()\n", - "\n", - "# Dataset\n", - "parser.add_argument('-d', '--dataset', choices=['encodeTFBS', 'amazon', 'camelyon17', 'celebA', 'civilcomments', 'iwildcam', 'waterbirds', 'yelp', 'poverty', 'fmow', 'ogbg-molpcba'], required=True)\n", - "parser.add_argument('--split_scheme', default='standard',\n", - " help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')\n", - "parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--root_dir', default=None, required=True,\n", - " help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')\n", - "parser.add_argument('--download', default=False, action='store_true',\n", - " help='If true, tries to downloads the dataset if it does not exist in root_dir.')\n", - "parser.add_argument('--frac', type=float, default=1.0,\n", - " help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.')\n", - "\n", - "# Loaders\n", - "parser.add_argument('--train_loader', choices=['standard', 'group'], default='standard')\n", - "parser.add_argument('--train_loader_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--eval_loader', choices=['standard', 'group'], default='standard')\n", - "parser.add_argument('--batch_size', type=int, default=32)\n", - "parser.add_argument('--no_pin_memory', action='store_true') # TODO: put as loader_kwargs\n", - "parser.add_argument('--num_workers', type=int, default=4) # TODO: put as loader kwargs\n", - "\n", - "# Model\n", - "parser.add_argument(\n", - " '--model',\n", - " choices=['bert-base-uncased', 'inception_v3', 'densenet121', 'wideresnet50', 'resnet50', 'gin-virtual', 'resnet18_ms'],\n", - " default='resnet50')\n", - "parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},\n", - " help='keyword arguments for model initialization passed as key1=value1 key2=value2')\n", - "parser.add_argument('--train_from_scratch', action='store_true', default=False)\n", - "\n", - "# Algorithm and objective\n", - "parser.add_argument('--algorithm', required=True, choices=algorithm_constructors.keys())\n", - "parser.add_argument('--algorithm_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--groupby_fields', nargs='+', default=None)\n", - "parser.add_argument('--loss_function', required=True, choices = losses.keys()) #TODO: make default\n", - "parser.add_argument('--val_metric', default=None)\n", - "\n", - "# Optimization\n", - "parser.add_argument('--n_epochs', type=int, default=4)\n", - "parser.add_argument('--optimizer', default=None, choices=optimizer_attributes.keys())\n", - "parser.add_argument('--lr', type=float, required=True)\n", - "parser.add_argument('--weight_decay', type=float, required=True)\n", - "parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--scheduler', default=None, choices=scheduler_attributes.keys())\n", - "parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')\n", - "parser.add_argument('--scheduler_metric_name')\n", - "\n", - "# Evaluation\n", - "parser.add_argument('--evaluate_all_splits', action='store_true', default=False)\n", - "parser.add_argument('--additional_eval_splits', nargs='+', default=[])\n", - "\n", - "# Misc\n", - "parser.add_argument('--device', type=int, default=0)\n", - "parser.add_argument('--seed', type=int, default=0)\n", - "parser.add_argument('--log_dir', default='./logs')\n", - "parser.add_argument('--log_every', default=50, type=int)\n", - "parser.add_argument('--save_step', type=int, default=None)\n", - "parser.add_argument('--save_best', action='store_true', default=False)\n", - "parser.add_argument('--save_last', action='store_true', default=False)\n", - "parser.add_argument('--save_outputs', action='store_true', default=False)\n", - "parser.add_argument('--no_group_logging', action='store_true', default=False)\n", - "parser.add_argument('--val_metric_decreasing', action='store_true', default=False)\n", - "parser.add_argument('--use_wandb', action='store_true', default=False)\n", - "parser.add_argument('--progress_bar', action='store_true', default=False)\n", - "parser.add_argument('--resume', default=False, action='store_true')\n", - "parser.add_argument('--eval_only', default=False, action='store_true')\n", - "\n", - "args = parser.parse_args(args_kw)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# get_input (idx)" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name '_metadata_df' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0midx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mthis_metadata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_metadata_df\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miloc\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mitime\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mflank_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m400\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mNameError\u001b[0m: name '_metadata_df' is not defined" - ] - } - ], - "source": [ - "idx = 3\n", - "this_metadata = _metadata_df.iloc[idx, :]\n", - "\n", - "itime = time.time()\n", - "flank_size = 400\n", - "interval_start = this_metadata['start'] - flank_size\n", - "interval_end = this_metadata['stop'] + flank_size\n", - "dnase_this = _dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end]\n", - "seq_this = _seq_bp[this_metadata['chr']][interval_start:interval_end]\n", - "data = np.column_stack([seq_this, dnase_this])\n", - "# print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.028102874755859375\n" - ] - } - ], - "source": [ - "itime = time.time()\n", - "metadata_array = torch.stack(\n", - " (torch.LongTensor(chr_ints), \n", - " torch.LongTensor(celltype_ints), \n", - " _y_array),\n", - " dim=1)\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "metadata": {}, - "outputs": [ - { - "ename": "ModuleNotFoundError", - "evalue": "No module named 'torch_scatter'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m#data.shape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata_loaders\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_train_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mget_eval_loader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m~/dr_benchmark/wilds/common/data_loaders.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msampler\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mWeightedRandomSampler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSubsetRandomSampler\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_counts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msplit_into_groups\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/dr_benchmark/wilds/common/utils.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch_scatter\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mSubset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mpandas\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtypes\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCategoricalDtype\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torch_scatter'" - ] - } - ], - "source": [ - "#data.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 157, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "4600" - ] - }, - "execution_count": 157, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "data.shape\n", - "interval_end\n", - "# itime = time.time()\n", - "# np.save(os.path.join(data_dir, 'stmp.npy'), sa)\n", - "# print(time.time() - itime)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Run training experiment" - ] - }, - { - "cell_type": "code", - "execution_count": 167, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'python3 examples/run_expt.py -d encodeTFBS --algorithm ERM --model densenet121 --split_scheme standard --groupby_fields hospital --loss_function cross_entropy --optimizer SGD --lr 0.0001 --batch_size 32 --weight_decay 0 --n_epochs 10 --scheduler ReduceLROnPlateau --scheduler_metric_split val --scheduler_metric_name acc_avg --log_dir log --log_every 50 --save_step 1000 --save_best --save_last --seed 0 --evaluate_all_splits --root_dir ROOTDIR'" - ] - }, - "execution_count": 167, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "cmdstr = \"python3 examples/run_expt.py -d encodeTFBS --algorithm ERM --model densenet121 --split_scheme standard --groupby_fields hospital --loss_function cross_entropy\"\n", - "cmdstr += \" \"\n", - "cmdstr += \"--optimizer SGD --lr 0.0001 --batch_size 32 --weight_decay 0 --n_epochs 10 --scheduler ReduceLROnPlateau --scheduler_metric_split val --scheduler_metric_name acc_avg\"\n", - "cmdstr += \" \"\n", - "cmdstr += \"--log_dir log --log_every 50 --save_step 1000 --save_best --save_last --seed 0 --evaluate_all_splits --root_dir ROOTDIR\"\n", - "cmdstr" - ] - }, - { - "cell_type": "code", - "execution_count": 164, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name '_metadata_array' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0m_metadata_array\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mNameError\u001b[0m: name '_metadata_array' is not defined" - ] - } - ], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 165, - "metadata": {}, - "outputs": [ - { - "ename": "ModuleNotFoundError", - "evalue": "No module named 'torch_scatter'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minsert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'..'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata_loaders\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_train_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mget_eval_loader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 17\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrouper\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCombinatorialGrouper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_counts\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/dr_benchmark/wilds/common/data_loaders.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msampler\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mWeightedRandomSampler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSubsetRandomSampler\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_counts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msplit_into_groups\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/dr_benchmark/wilds/common/utils.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch_scatter\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mSubset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mpandas\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtypes\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCategoricalDtype\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torch_scatter'" - ] - } - ], - "source": [ - "import os, csv\n", - "import time\n", - "import argparse\n", - "import IPython\n", - "import pandas as pd\n", - "import torch\n", - "import torch.nn as nn\n", - "import torchvision\n", - "import sys\n", - "from collections import defaultdict\n", - "# torch.multiprocessing.set_sharing_strategy('file_system')\n", - "\n", - "# TODO: Replace this once we make wilds into an installed package\n", - "sys.path.insert(1, os.path.join(sys.path[0], '..'))\n", - "\n", - "from wilds.common.data_loaders import get_train_loader, get_eval_loader\n", - "from wilds.common.grouper import CombinatorialGrouper\n", - "from wilds.common.utils import get_counts\n", - "\n", - "from models.model_attributes import model_attributes\n", - "from utils import set_seed, Logger, BatchLogger, log_args, ParseKwargs, load\n", - "from train import train, evaluate\n", - "from data import dataset_attributes\n", - "from optimizer import optimizer_attributes\n", - "from scheduler import scheduler_attributes\n", - "from loss import losses\n", - "from utils import log_group_data\n", - "from algorithms.constructors import algorithm_constructors" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "from examples.models.model_attributes import model_attributes" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "ename": "ModuleNotFoundError", - "evalue": "No module named 'utils'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_attributes\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmodel_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mset_seed\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mLogger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mCSVBatchLogger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mParseKwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mload\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 22\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdataset_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0moptimizer_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/dr_benchmark/examples/train.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0msave\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'utils'" - ] - } - ], - "source": [ - "def initialize_algorithm(args, datasets, train_grouper):\n", - " train_dataset = datasets['train']['dataset']\n", - " train_loader = datasets['train']['loader']\n", - "\n", - " # Configure the final layer of the networks used\n", - " # The code below are defaults. Edit this if you need special config for your model.\n", - " if (train_dataset.is_classification) and (train_dataset.y_size == 1):\n", - " # For single-task classification, we have one output per class\n", - " d_out = train_dataset.n_classes\n", - " elif (train_dataset.is_classification) and (train_dataset.y_size > 1) and (train_dataset.n_classes == 2):\n", - " # For multi-task binary classification (each output is the logit for each binary class)\n", - " d_out = train_dataset.y_size\n", - " elif (not train_dataset.is_classification):\n", - " # For regression, we have one output per target dimension\n", - " d_out = train_dataset.y_size\n", - " else:\n", - " raise RuntimeError('d_out not defined.')\n", - " \n", - "\n", - " # Sanity checking input args\n", - " if args.algorithm == 'groupDRO':\n", - " assert args.train_loader_kwargs['uniform_over_groups']\n", - " elif args.algorithm in ['deepCORAL', 'IRM']:\n", - " assert args.train_loader == 'group'\n", - " assert args.train_loader_kwargs['uniform_over_groups']\n", - " assert args.train_loader_kwargs['distinct_groups']\n", - "\n", - " # Other config\n", - " n_train_steps = len(train_loader) * args.n_epochs\n", - "# prediction_fn = dataset_attributes[args.dataset]['prediction_fn']\n", - " loss = losses[args.loss_function]\n", - " metric = dataset_attributes[args.dataset]['metric']\n", - " train_g = train_grouper.metadata_to_group(train_dataset.metadata_array)\n", - " is_group_in_train = get_counts(train_g, train_grouper.n_groups) > 0\n", - " algorithm_constructor = algorithm_constructors[args.algorithm]\n", - " algorithm = algorithm_constructor(\n", - " args=args,\n", - " d_out=d_out,\n", - " grouper=train_grouper,\n", - " loss=loss,\n", - " metric=metric,\n", - " n_train_steps=n_train_steps,\n", - " is_group_in_train=is_group_in_train)\n", - " return algorithm" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def main():\n", - " parser = argparse.ArgumentParser()\n", - "\n", - " # Dataset\n", - " parser.add_argument('-d', '--dataset', choices=dataset_attributes.keys(), required=True)\n", - " parser.add_argument('--split_scheme', default='standard',\n", - " help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')\n", - " parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={})\n", - " parser.add_argument('--root_dir', default=None, required=True,\n", - " help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')\n", - " parser.add_argument('--download', default=False, action='store_true',\n", - " help='If true, tries to downloads the dataset if it does not exist in root_dir.')\n", - " parser.add_argument('--frac', type=float, default=1.0,\n", - " help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.')\n", - "\n", - " # Loaders\n", - " parser.add_argument('--train_loader', choices=['standard', 'group'], default='standard')\n", - " parser.add_argument('--train_loader_kwargs', nargs='*', action=ParseKwargs, default={})\n", - " parser.add_argument('--eval_loader', choices=['standard', 'group'], default='standard')\n", - " parser.add_argument('--batch_size', type=int, default=32)\n", - " parser.add_argument('--no_pin_memory', action='store_true') # TODO: put as loader_kwargs\n", - " parser.add_argument('--num_workers', type=int, default=4) # TODO: put as loader kwargs\n", - "\n", - " # Model\n", - " parser.add_argument(\n", - " '--model',\n", - " choices=model_attributes.keys(),\n", - " default='resnet50')\n", - " parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},\n", - " help='keyword arguments for model initialization passed as key1=value1 key2=value2')\n", - " parser.add_argument('--train_from_scratch', action='store_true', default=False)\n", - "\n", - " # Algorithm and objective\n", - " parser.add_argument('--algorithm', required=True, choices=algorithm_constructors.keys())\n", - " parser.add_argument('--algorithm_kwargs', nargs='*', action=ParseKwargs, default={})\n", - " parser.add_argument('--groupby_fields', nargs='+', default=None)\n", - " parser.add_argument('--loss_function', required=True, choices = losses.keys()) #TODO: make default\n", - " parser.add_argument('--val_metric', default=None)\n", - "\n", - " # Optimization\n", - " parser.add_argument('--n_epochs', type=int, default=4)\n", - " parser.add_argument('--optimizer', default=None, choices=optimizer_attributes.keys())\n", - " parser.add_argument('--lr', type=float, required=True)\n", - " parser.add_argument('--weight_decay', type=float, required=True)\n", - " parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})\n", - " parser.add_argument('--scheduler', default=None, choices=scheduler_attributes.keys())\n", - " parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})\n", - " parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')\n", - " parser.add_argument('--scheduler_metric_name')\n", - "\n", - " # Evaluation\n", - " parser.add_argument('--evaluate_all_splits', action='store_true', default=False)\n", - " parser.add_argument('--additional_eval_splits', nargs='+', default=[])\n", - "\n", - " # Misc\n", - " parser.add_argument('--device', type=int, default=0)\n", - " parser.add_argument('--seed', type=int, default=0)\n", - " parser.add_argument('--log_dir', default='./logs')\n", - " parser.add_argument('--log_every', default=50, type=int)\n", - " parser.add_argument('--save_step', type=int, default=None)\n", - " parser.add_argument('--save_best', action='store_true', default=False)\n", - " parser.add_argument('--save_last', action='store_true', default=False)\n", - " parser.add_argument('--save_outputs', action='store_true', default=False)\n", - " parser.add_argument('--no_group_logging', action='store_true', default=False)\n", - " parser.add_argument('--val_metric_decreasing', action='store_true', default=False)\n", - " parser.add_argument('--use_wandb', action='store_true', default=False)\n", - " parser.add_argument('--progress_bar', action='store_true', default=False)\n", - " parser.add_argument('--resume', default=False, action='store_true')\n", - " parser.add_argument('--eval_only', default=False, action='store_true')\n", - "\n", - " args = parser.parse_args()\n", - "\n", - " # set device\n", - " args.device = torch.device(\"cuda:\" + str(args.device)) if torch.cuda.is_available() else torch.device(\"cpu\")\n", - "\n", - " # Set defaults\n", - " if args.groupby_fields is None:\n", - " args.no_group_logging = True\n", - " if args.val_metric is None:\n", - " args.val_metric = dataset_attributes[args.dataset]['val_metric']\n", - "\n", - " ## Initialize logs\n", - " if os.path.exists(args.log_dir) and args.resume:\n", - " resume=True\n", - " mode='a'\n", - " else:\n", - " resume=False\n", - " mode='w'\n", - " if not os.path.exists(args.log_dir):\n", - " os.makedirs(args.log_dir)\n", - " logger = Logger(os.path.join(args.log_dir, 'log.txt'), mode)\n", - "\n", - " # Record args\n", - " log_args(args, logger)\n", - "\n", - " # Set random seed\n", - " set_seed(args.seed)\n", - "\n", - " # Data\n", - " full_dataset = dataset_attributes[args.dataset]['constructor'](\n", - " root_dir=args.root_dir,\n", - " download=args.download,\n", - " split_scheme=args.split_scheme,\n", - " **args.dataset_kwargs)\n", - "\n", - " # To implement data augmentation (i.e., have different transforms\n", - " # at training time vs. test time), modify these two lines:\n", - " train_transform = dataset_attributes[args.dataset]['transform'](args.model)\n", - " if dataset_attributes[args.dataset].get('eval_transform') is None:\n", - " eval_transform = dataset_attributes[args.dataset]['transform'](args.model)\n", - " else:\n", - " eval_transform = dataset_attributes[args.dataset]['eval_transform'](args.model)\n", - "\n", - " train_grouper = CombinatorialGrouper(\n", - " dataset=full_dataset,\n", - " groupby_fields=args.groupby_fields)\n", - "\n", - " datasets = defaultdict(dict)\n", - " for split in full_dataset.split_dict.keys():\n", - " if split=='train':\n", - " transform = train_transform\n", - " verbose = True\n", - " elif split == 'val':\n", - " transform = eval_transform\n", - " verbose = True\n", - " else:\n", - " transform = eval_transform\n", - " verbose = False\n", - " # Get subset\n", - " datasets[split]['dataset'] = full_dataset.get_subset(\n", - " split,\n", - " frac=args.frac,\n", - " transform=transform)\n", - "\n", - " # Get loader\n", - " shared_loader_kwargs = {\n", - " 'num_workers': args.num_workers,\n", - " 'pin_memory': not args.no_pin_memory,\n", - " 'batch_size': args.batch_size,\n", - " 'collate_fn': dataset_attributes[args.dataset]['collate']\n", - " }\n", - "\n", - " if split == 'train':\n", - " datasets[split]['loader'] = get_train_loader(\n", - " loader=args.train_loader,\n", - " dataset=datasets[split]['dataset'],\n", - " grouper=train_grouper,\n", - " train_loader_kwargs=args.train_loader_kwargs,\n", - " **shared_loader_kwargs)\n", - " else:\n", - " datasets[split]['loader'] = get_eval_loader(\n", - " loader=args.eval_loader,\n", - " dataset=datasets[split]['dataset'],\n", - " grouper=train_grouper,\n", - " **shared_loader_kwargs)\n", - "\n", - " # Set fields\n", - " datasets[split]['split'] = split\n", - " datasets[split]['name'] = full_dataset.split_names[split]\n", - " datasets[split]['verbose'] = verbose\n", - " # Loggers\n", - " # Loggers\n", - " datasets[split]['eval_logger'] = BatchLogger(\n", - " os.path.join(args.log_dir, f'{split}_eval.csv'), mode=mode, use_wandb=args.use_wandb)\n", - " datasets[split]['algo_logger'] = BatchLogger(\n", - " os.path.join(args.log_dir, f'{split}_algo.csv'), mode=mode, use_wandb=args.use_wandb)\n", - "\n", - " if args.use_wandb:\n", - " initialize_wandb(args)\n", - "\n", - " # Logging dataset info\n", - " if args.no_group_logging and full_dataset.is_classification and full_dataset.y_size==1:\n", - " log_grouper = CombinatorialGrouper(\n", - " dataset=full_dataset,\n", - " groupby_fields=['y'])\n", - " elif args.no_group_logging:\n", - " log_grouper = None\n", - " else:\n", - " log_grouper = train_grouper\n", - " log_group_data(args, datasets, log_grouper, logger)\n", - "\n", - " ## Initialize algorithm\n", - " algorithm = initialize_algorithm(args, datasets, train_grouper)\n", - "\n", - " if not args.eval_only:\n", - " ## Load saved results if resuming\n", - " resume_success = False\n", - " if resume:\n", - " save_path = os.path.join(args.log_dir, 'last_model.pth')\n", - " if not os.path.exists(save_path):\n", - " epochs = [\n", - " int(file.split('_')[0])\n", - " for file in os.listdir(args.log_dir) if file.endswith('.pth')]\n", - " if len(epochs) > 0:\n", - " latest_epoch = max(epochs)\n", - " save_path = os.path.join(args.log_dir, f'{latest_epoch}_model.pth')\n", - " try:\n", - " prev_epoch, best_val_metric = load(algorithm, save_path)\n", - " epoch_offset = prev_epoch + 1\n", - " logger.write(f'Resuming from epoch {epoch_offset} with best val metric {best_val_metric}')\n", - " resume_success = True\n", - " except FileNotFoundError:\n", - " pass\n", - "\n", - " if resume_success == False:\n", - " epoch_offset=0\n", - " best_val_metric=None\n", - "\n", - "\n", - " train(algorithm,\n", - " datasets,\n", - " logger,\n", - " args,\n", - " epoch_offset=epoch_offset,\n", - " best_val_metric=best_val_metric)\n", - " else:\n", - " best_model_path = os.path.join(args.log_dir, 'best_model.pth')\n", - " best_epoch, best_val_metric = load(algorithm, best_model_path)\n", - " evaluate(algorithm, datasets, best_epoch, logger)\n", - "\n", - " logger.close()\n", - " for split in datasets:\n", - " datasets[split]['eval_logger'].close()\n", - " datasets[split]['algo_logger'].close()\n", - "\n", - "if __name__=='__main__':\n", - " main()\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 2", - "language": "python", - "name": "python2" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 2 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.13" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/dataset_preprocessing/encode-tfbs/README.md b/dataset_preprocessing/encode-tfbs/README.md new file mode 100644 index 00000000..0be5fbd6 --- /dev/null +++ b/dataset_preprocessing/encode-tfbs/README.md @@ -0,0 +1,18 @@ +## ENCODE-TFBS-wilds feature generation and preprocessing + +#### Requirements +- pyBigWig + +#### Instructions + +1. Download the human genome sequence (hg19 assembly) in FASTA format from http://hgdownload.cse.ucsc.edu/goldenpath/hg19/bigZips/hg19.fa.gz into `SEQUENCE_PATH`. + +2. Run `python prep_sequence.py --seq_path SEQUENCE_PATH --output_dir OUTPUT_DIR` to write the fasta file found in `SEQUENCE_PATH` to a numpy array archive in `OUTPUT_DIR`. + +3. Download the accessibility data from the challenge. This consists of whole-genome DNase files in bigwig format (*.bw) from https://www.synapse.org/#!Synapse:syn6176233. + +4. Run `python prep_accessibility.py --input_dir INPUT_DIR --output_dir OUTPUT_DIR` to extract the bigwigs into numpy array archives, one per celltype. + +5. Download the labels from the challenge into a label directory created for this purpose: + - The training labels from https://www.synapse.org/#!Synapse:syn7413983 for the relevant transcription factor (e.g. https://www.synapse.org/#!Synapse:syn7415202 for the TF MAX). + - The validation labels from https://www.synapse.org/#!Synapse:syn8441154 for the relevant transcription factor (e.g. https://www.synapse.org/#!Synapse:syn8442103 for the TF MAX). diff --git a/dataset_preprocessing/encode-tfbs/prep_accessibility.py b/dataset_preprocessing/encode-tfbs/prep_accessibility.py new file mode 100644 index 00000000..9033224e --- /dev/null +++ b/dataset_preprocessing/encode-tfbs/prep_accessibility.py @@ -0,0 +1,41 @@ +import numpy, pandas +import pyBigWig + +from tqdm import tqdm + + +def generate_accessibility_archives(input_dir, output_dir): + dnases = {} + celltypes = ['A549', 'GM12878', 'H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] + + for ctype in celltypes:#glob.glob('dnase_bigwigs/*'): + itime = time.time() + # ctype = pth.split('/')[1].split('.')[1] + bw = pyBigWig.open("{}/DNASE.{}.fc.signal.bigwig".format(input_dir, ctype)) + chromsizes = bw.chroms() + print(ctype, time.time() - itime) + dn_dict = {} + for chrom in chromsizes: #chr_IDs: + x = bw.values(chrom, 0, chromsizes[chrom], numpy=True) + dn_dict[chrom] = np.nan_to_num(x).astype(np.float16) # half-precision makes things significantly smaller (less time to load) + print(chrom, time.time() - itime) + dnases[ctype] = dn_dict + + for ctype in dnases: + itime = time.time() + dn_dict = dnases[ctype] + + # Save as npz archive + np.savez_compressed('{}/{}_dnase'.format(output_dir, ctype), **dn_dict) + print("Saving npz archive for celltype {}. Time: {}".format(ctype, time.time() - itime)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--input_dir', required=True) + parser.add_argument('--output_dir', required=True) + args = parser.parse_args() + + generate_accessibility_archives( + input_dir=args.input_dir, + output_dir=args.output_dir) \ No newline at end of file diff --git a/dataset_preprocessing/encode-tfbs/prep_datasets.ipynb b/dataset_preprocessing/encode-tfbs/prep_datasets.ipynb new file mode 100644 index 00000000..4b1fdc10 --- /dev/null +++ b/dataset_preprocessing/encode-tfbs/prep_datasets.ipynb @@ -0,0 +1,279 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import prep_utils, scipy, numpy as np, time\n", + "from scipy import sparse\n", + "\n", + "# Human chromosome names\n", + "chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Sequence" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "62743362it [00:54, 1151676.47it/s]\n" + ] + } + ], + "source": [ + "a = prep_utils.read_fasta('sequence/hg19.genome.fa')\n", + "\n", + "kw_dict = {}\n", + "itime = time.time()\n", + "for chrom in chr_IDs:\n", + " seqstr = a[chrom]\n", + " kw_dict[chrom] = prep_utils.one_hot_encode(seqstr, alphabet=['A', 'C', 'G', 'T', 'N'])\n", + " print(chrom, time.time() - itime)\n", + "\n", + "# Save as npz archive; can take several (>20) minutes\n", + "print(\"Saving npz archive...\")\n", + "np.savez_compressed('codalab_archive/sequence', **kw_dict)\n", + "print(time.time() - itime)\n", + "\n", + "# # Save as npy arrays\n", + "# itime = time.time()\n", + "# for chrom in kw_dict:\n", + "# np.save('sequence/{}.npy'.format(chrom), kw_dict[chrom])\n", + "# print(chrom, time.time() - itime)\n", + "\n", + "npz_archive = np.load('codalab_archive/sequence.npz')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## DNase" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "liver 0.006468534469604492\n", + "chr1 8.260387659072876\n", + "chr1 13.276052474975586\n", + "chr10 17.844778299331665\n", + "chr10 25.784512758255005\n", + "chr11 30.30143165588379\n", + "chr11 33.256701707839966\n", + "chr12 37.791435956954956\n", + "chr12 40.85292291641235\n", + "chr13 44.619521141052246\n", + "chr13 47.792500495910645\n", + "chr14 51.4214243888855\n", + "chr14 53.6813702583313\n", + "chr15 56.946401834487915\n", + "chr15 59.10466551780701\n", + "chr16 61.939475774765015\n", + "chr16 63.999470472335815\n", + "chr17 66.63648653030396\n", + "chr17 68.4126443862915\n", + "chr18 71.05454993247986\n", + "chr18 72.90085673332214\n", + "chr19 74.78594756126404\n", + "chr19 76.80954170227051\n", + "chr2 85.25815343856812\n", + "chr2 95.36479425430298\n", + "chr20 97.74516272544861\n", + "chr20 99.27151441574097\n", + "chr21 100.82207584381104\n", + "chr21 103.02815318107605\n", + "chr22 104.63926863670349\n", + "chr22 106.02127361297607\n", + "chr3 112.71910071372986\n", + "chr3 117.30491018295288\n", + "chr4 123.77405095100403\n", + "chr4 128.67069339752197\n", + "chr5 134.89299392700195\n", + "chr5 138.83413815498352\n", + "chr6 144.83386087417603\n", + "chr6 149.115407705307\n", + "chr7 154.4929392337799\n", + "chr7 157.8094253540039\n", + "chr8 162.8749077320099\n", + "chr8 165.9331293106079\n", + "chr9 170.5435709953308\n", + "chr9 173.46287417411804\n", + "chrX 178.5410988330841\n", + "chrX 185.49569463729858\n", + "chrY 187.14469981193542\n", + "chrY 189.6306025981903\n", + "MCF-7 0.01819300651550293\n", + "chr1 8.266149282455444\n", + "chr1 13.86928129196167\n", + "chr10 18.216674327850342\n", + "chr10 20.975315809249878\n", + "chr11 25.302175998687744\n", + "chr11 34.40013885498047\n", + "chr12 38.70525503158569\n", + "chr12 41.59175777435303\n", + "chr13 45.130286693573\n", + "chr13 47.67305374145508\n", + "chr14 51.26033353805542\n", + "chr14 53.59153509140015\n", + "chr15 56.858047008514404\n", + "chr15 59.08759665489197\n", + "chr16 62.03992414474487\n", + "chr16 63.99170207977295\n", + "chr17 67.05595779418945\n", + "chr17 69.3644654750824\n", + "chr18 71.78018283843994\n", + "chr18 73.58044695854187\n", + "chr19 75.70175457000732\n", + "chr19 79.72573828697205\n", + "chr2 87.675612449646\n", + "chr2 92.91672372817993\n", + "chr20 95.51653027534485\n", + "chr20 96.88600373268127\n", + "chr21 98.43806076049805\n", + "chr21 103.25369572639465\n", + "chr22 104.84882092475891\n", + "chr22 106.21143817901611\n", + "chr3 112.67947244644165\n", + "chr3 116.70610451698303\n", + "chr4 122.56520342826843\n", + "chr4 126.52856135368347\n", + "chr5 132.38469552993774\n", + "chr5 136.28370690345764\n", + "chr6 141.5743978023529\n", + "chr6 145.10061717033386\n", + "chr7 150.44007444381714\n", + "chr7 155.55760312080383\n", + "chr8 160.3683557510376\n", + "chr8 163.43416213989258\n", + "chr9 167.90313267707825\n", + "chr9 172.0667405128479\n", + "chrX 176.69336795806885\n", + "chrX 181.83150935173035\n", + "K562 0.007167339324951172\n", + "chr1 8.471662998199463\n", + "chr1 13.464861631393433\n", + "chr10 17.858335494995117\n", + "chr10 20.700791835784912\n", + "chr11 25.168848276138306\n", + "chr11 28.01260733604431\n", + "chr12 32.38129758834839\n", + "chr12 35.250038385391235\n", + "chr13 38.72063398361206\n", + "chr13 43.30442762374878\n", + "chr14 46.55065989494324\n", + "chr14 51.87103271484375\n", + "chr15 55.08980083465576\n", + "chr15 57.35198903083801\n", + "chr16 60.444990396499634\n", + "chr16 62.56146717071533\n", + "chr17 65.33607196807861\n", + "chr17 75.77480912208557\n", + "chr18 78.25007915496826\n", + "chr18 82.4424319267273\n", + "chr19 84.73718905448914\n", + "chr19 86.0900673866272\n", + "chr2 93.6916708946228\n", + "chr2 98.61803960800171\n", + "chr20 100.70567536354065\n", + "chr20 102.18551921844482\n", + "chr21 103.75095820426941\n", + "chr21 104.96330642700195\n", + "chr22 106.666348695755\n", + "chr22 108.20869731903076\n", + "chr3 114.6058874130249\n", + "chr3 123.16646194458008\n", + "chr4 129.07538533210754\n", + "chr4 135.95439338684082\n", + "chr5 141.63543701171875\n", + "chr5 148.8255476951599\n", + "chr6 154.68585968017578\n", + "chr6 160.3087387084961\n", + "chr7 165.7410364151001\n", + "chr7 169.09255123138428\n", + "chr8 173.68864274024963\n", + "chr8 176.73100185394287\n", + "chr9 181.10383462905884\n", + "chr9 184.0267071723938\n", + "chrX 188.59823846817017\n", + "chrX 191.7538366317749\n" + ] + } + ], + "source": [ + "### import pyBigWig\n", + "import glob\n", + "\n", + "dnases = {}\n", + "celltypes = ['A549', 'GM12878', 'H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", + "\n", + "for ctype in celltypes:#glob.glob('dnase_bigwigs/*'):\n", + " itime = time.time()\n", + " # ctype = pth.split('/')[1].split('.')[1]\n", + " if ctype not in ['liver', 'MCF-7', 'K562']:\n", + " continue\n", + " bw = pyBigWig.open(\"dnase_bigwigs/DNASE.{}.fc.signal.bigwig\".format(ctype))\n", + " chromsizes = bw.chroms()\n", + " print(ctype, time.time() - itime)\n", + " dn_dict = {}\n", + " for chrom in chromsizes: #chr_IDs:\n", + " x = bw.values(chrom, 0, chromsizes[chrom], numpy=True)\n", + " dn_dict[chrom] = np.nan_to_num(x).astype(np.float16) # half-precision makes things significantly smaller (less time to load)\n", + " print(chrom, time.time() - itime)\n", + " \n", + " np.save('dnase/{}/{}.npy'.format(ctype, chrom), dn_dict[chrom])\n", + " print(chrom, time.time() - itime)\n", + " dnases[ctype] = dn_dict\n", + "\n", + "for ctype in dnases:\n", + " itime = time.time()\n", + " print(ctype)\n", + " dn_dict = dnases[ctype]\n", + " \n", + " # Save as npz archive\n", + " np.savez_compressed('codalab_archive/{}_dnase'.format(ctype), **dn_dict)\n", + " print(time.time() - itime)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "language": "python", + "name": "python2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/dataset_preprocessing/encode-tfbs/prep_sequence.py b/dataset_preprocessing/encode-tfbs/prep_sequence.py new file mode 100644 index 00000000..5a0baea5 --- /dev/null +++ b/dataset_preprocessing/encode-tfbs/prep_sequence.py @@ -0,0 +1,151 @@ +import argparse, time +import numpy, pandas + +from tqdm import tqdm + + +def one_hot_encode(sequence, ignore='N', alphabet=None, dtype='int8', + verbose=False, **kwargs): + """Converts a string or list of characters into a one-hot encoding. + This function will take in either a string or a list and convert it into a + one-hot encoding. If the input is a string, each character is assumed to be + a different symbol, e.g. 'ACGT' is assumed to be a sequence of four + characters. If the input is a list, the elements can be any size. + Although this function will be used here primarily to convert nucleotide + sequences into one-hot encoding with an alphabet of size 4, in principle + this function can be used for any types of sequences. + Parameters + ---------- + sequence : str or list + The sequence to convert to a one-hot encoding. + ignore : str, optional + A character to indicate setting nothing to 1 for that row, keeping the + encoding entirely 0's for that row. In the context of genomics, this is + the N character. Default is 'N'. + alphabet : set or tuple or list, optional + A pre-defined alphabet. If None is passed in, the alphabet will be + determined from the sequence, but this may be time consuming for + large sequences. Default is None. + dtype : str or numpy.dtype, optional + The data type of the returned encoding. Default is int8. + verbose : bool or str, optional + Whether to display a progress bar. If a string is passed in, use as the + name of the progressbar. Default is False. + kwargs : arguments + Arguments to be passed into tqdm. Default is None. + Returns + ------- + ohe : numpy.ndarray + A binary matrix of shape (alphabet_size, sequence_length) where + alphabet_size is the number of unique elements in the sequence and + sequence_length is the length of the input sequence. + """ + + name = None if verbose in (True, False) else verbose + d = verbose is False + + if isinstance(sequence, str): + sequence = list(sequence) + + alphabet = alphabet or numpy.unique(sequence) + alphabet = [char for char in alphabet if char != ignore] + alphabet_lookup = {char: i for i, char in enumerate(alphabet)} + + ohe = numpy.zeros((len(sequence), len(alphabet)), dtype=dtype) + for i, char in tqdm(enumerate(sequence), disable=d, desc=name, **kwargs): + if char != ignore: + idx = alphabet_lookup[char] + ohe[i, idx] = 1 + + return ohe + + +def read_fasta(filename, include_chroms=None, exclude_chroms=None, + ignore='N', alphabet=['A', 'C', 'G', 'T', 'N'], verbose=True): + """Read in a FASTA file and output a dictionary of sequences. + This function will take in the path to a FASTA-formatted file and output + a string containing the sequence for each chromosome. Optionally, + the user can specify a set of chromosomes to include or exclude from + the returned dictionary. + Parameters + ---------- + filename : str + The path to the FASTA-formatted file to open. + include_chroms : set or tuple or list, optional + The exact names of chromosomes in the FASTA file to include, excluding + all others. If None, include all chromosomes (except those specified by + exclude_chroms). Default is None. + exclude_chroms : set or tuple or list, optional + The exact names of chromosomes in the FASTA file to exclude, including + all others. If None, include all chromosomes (or the set specified by + include_chroms). Default is None. + ignore : str, optional + A character to indicate setting nothing to 1 for that row, keeping the + encoding entirely 0's for that row. In the context of genomics, this is + the N character. Default is 'N'. + alphabet : set or tuple or list, optional + A pre-defined alphabet. If None is passed in, the alphabet will be + determined from the sequence, but this may be time consuming for + large sequences. Must include the ignore character. Default is + ['A', 'C', 'G', 'T', 'N']. + verbose : bool or str, optional + Whether to display a progress bar. If a string is passed in, use as the + name of the progressbar. Default is False. + Returns + ------- + chroms : dict + A dictionary of strings where the keys are the names of the + chromosomes (exact strings from the header lines in the FASTA file) + and the values are the strings encoded there. + """ + + sequences = {} + name, sequence = None, None + skip_chrom = False + + with open(filename, "r") as infile: + for line in tqdm(infile, disable=not verbose): + if line.startswith(">"): + if name is not None and skip_chrom is False: + sequences[name] = ''.join(sequence) + + sequence = [] + name = line[1:].strip("\n") + if include_chroms is not None and name not in include_chroms: + skip_chrom = True + elif exclude_chroms is not None and name in exclude_chroms: + skip_chrom = True + else: + skip_chrom = False + + else: + if skip_chrom == False: + sequence.append(line.rstrip("\n").upper()) + + return sequences + + +def generate_sequence_archive(seq_path='sequence/hg19.genome.fa', output_dir): + fasta_contents = read_fasta() + kw_dict = {} + itime = time.time() + for chrom in chr_IDs: + seqstr = fasta_contents[chrom] + kw_dict[chrom] = one_hot_encode(seqstr, alphabet=['A', 'C', 'G', 'T', 'N']) + print(chrom, time.time() - itime) + + # Save as npz archive; can take several (>20) minutes + print("Saving npz archive...") + np.savez_compressed('{}/sequence'.format(output_root), **kw_dict) + print(time.time() - itime) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--seq_path', required=True) + parser.add_argument('--output_dir', required=True) + args = parser.parse_args() + + generate_sequence_archive( + seq_path=args.seq_path, + output_dir=args.output_dir) \ No newline at end of file diff --git a/dataset_preprocessing/encode-tfbs/sandbox_data.ipynb b/sandbox_data.ipynb similarity index 97% rename from dataset_preprocessing/encode-tfbs/sandbox_data.ipynb rename to sandbox_data.ipynb index b2e74829..55a67da4 100644 --- a/dataset_preprocessing/encode-tfbs/sandbox_data.ipynb +++ b/sandbox_data.ipynb @@ -16,8 +16,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "57.5289368629\n", - "65.2459537983\n" + "50.2965240479\n", + "58.1326179504\n" ] } ], @@ -73,37 +73,37 @@ "name": "stdout", "output_type": "stream", "text": [ - "0.0467748641968\n", - "('chr1', 4.52302885055542)\n", - "('chr2', 8.645489931106567)\n", - "('chr3', 11.959153890609741)\n", - "('chr4', 15.15813684463501)\n", - "('chr5', 18.22238802909851)\n", - "('chr6', 21.19420099258423)\n", - "('chr7', 23.940655946731567)\n", - "('chr8', 26.415233850479126)\n", - "('chr9', 28.833614826202393)\n", - "('chr10', 31.08920383453369)\n", - "('chr11', 33.37020301818848)\n", - "('chr12', 35.98973989486694)\n", - "('chr13', 37.88540601730347)\n", - "('chr14', 39.68082284927368)\n", - "('chr15', 41.242313861846924)\n", - "('chr16', 42.74874496459961)\n", - "('chr17', 44.12280797958374)\n", - "('chr18', 45.46893382072449)\n", - "('chr19', 46.50577902793884)\n", - "('chr20', 47.59563183784485)\n", - "('chr21', 48.31779384613037)\n", - "('chr22', 49.17265295982361)\n", - "('chrX', 51.75806999206543)\n", - "('H1-hESC', 25.880441904067993)\n", - "('HCT116', 50.130937814712524)\n", - "('HeLa-S3', 75.29559993743896)\n", - "('HepG2', 102.25979495048523)\n", - "('K562', 128.43050694465637)\n", - "('A549', 154.80679488182068)\n", - "('GM12878', 182.0279529094696)\n" + "1.40137600899\n", + "('chr1', 4.365410089492798)\n", + "('chr2', 8.54686713218689)\n", + "('chr3', 11.915641069412231)\n", + "('chr4', 15.147382020950317)\n", + "('chr5', 18.221237182617188)\n", + "('chr6', 21.16081714630127)\n", + "('chr7', 23.87936806678772)\n", + "('chr8', 26.382845163345337)\n", + "('chr9', 28.802964210510254)\n", + "('chr10', 31.10539698600769)\n", + "('chr11', 33.392733097076416)\n", + "('chr12', 35.6597261428833)\n", + "('chr13', 37.56297421455383)\n", + "('chr14', 39.363978147506714)\n", + "('chr15', 41.089357137680054)\n", + "('chr16', 42.6117000579834)\n", + "('chr17', 43.9806342124939)\n", + "('chr18', 45.29493808746338)\n", + "('chr19', 46.26894497871399)\n", + "('chr20', 47.31300115585327)\n", + "('chr21', 48.139018058776855)\n", + "('chr22', 48.97876214981079)\n", + "('chrX', 51.61549210548401)\n", + "('H1-hESC', 24.14024806022644)\n", + "('HCT116', 47.97159004211426)\n", + "('HeLa-S3', 72.82926392555237)\n", + "('HepG2', 97.18733406066895)\n", + "('K562', 121.94148206710815)\n", + "('A549', 147.29550194740295)\n", + "('GM12878', 171.71312499046326)\n" ] } ], @@ -118,6 +118,7 @@ "for chrom in seq_arr:\n", " _seq_bp[chrom] = seq_arr[chrom]\n", " print(chrom, time.time() - itime)\n", + "print(\"Sequence read. Time: {}\".format(time.time() - itime))\n", "\n", "itime = time.time()\n", "_dnase_allcelltypes = {}\n", @@ -127,9 +128,17 @@ " _dnase_allcelltypes[ct] = {}\n", " for chrom in _seq_bp:\n", " _dnase_allcelltypes[ct][chrom] = dnase_npz_file[chrom]\n", - " print(ct, time.time() - itime)" + " print(ct, time.time() - itime)\n", + "print(\"DNase read for all celltypes. Time: {}\".format(time.time() - itime))" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": 10, diff --git a/sandbox_model.ipynb b/sandbox_model.ipynb index 885c8a59..2d62b55e 100644 --- a/sandbox_model.ipynb +++ b/sandbox_model.ipynb @@ -105,226 +105,294 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class Beagle2(nn.Module):\n", + " \"\"\"\n", + " Neural net models over genomic sequence.\n", + " Input:\n", + " - sequence_length: int (default 1000) \n", + " - Shape: (N, 5, sequence_length, 1) with batch size N.\n", + " \n", + " Output:\n", + " - prediction (Tensor): float torch tensor of shape (N, )\n", + " \n", + " TODO: Finish docstring.\n", + " \"\"\"\n", + " def __init__(self):\n", + " \"\"\"\n", + " Parameters\n", + " ----------\n", + " sequence_length : int\n", + " n_genomic_features : int\n", + " \"\"\"\n", + " super(Beagle2, self).__init__()\n", + "\n", + " self.dropout = 0.3\n", + " self.num_cell_types = 1\n", + " self.conv1 = nn.Conv2d(5, 300, (19, 1), stride = (1, 1), padding=(9,0))\n", + " self.conv2 = nn.Conv2d(300, 200, (11, 1), stride = (1, 1), padding = (5,0))\n", + " self.conv3 = nn.Conv2d(200, 200, (7, 1), stride = (1, 1), padding = (4,0))\n", + " self.bn1 = nn.BatchNorm2d(300)\n", + " self.bn2 = nn.BatchNorm2d(200)\n", + " self.bn3 = nn.BatchNorm2d(200)\n", + " self.maxpool1 = nn.MaxPool2d((3, 1))\n", + " self.maxpool2 = nn.MaxPool2d((4, 1))\n", + " self.maxpool3 = nn.MaxPool2d((4, 1))\n", + "\n", + " self.fc1 = nn.Linear(4200, 1000)\n", + " self.bn4 = nn.BatchNorm1d(1000)\n", + "\n", + " self.fc2 = nn.Linear(1000, 1000)\n", + " self.bn5 = nn.BatchNorm1d(1000)\n", + "\n", + " self.fc3 = nn.Linear(1000, self.num_cell_types)\n", + "\n", + " def forward(self, s):\n", + " s = s.permute(0, 2, 1).contiguous() # batch_size x 4 x 1000\n", + " s = s.view(-1, 5, 1000, 1) # batch_size x 4 x 1000 x 1 [4 channels]\n", + " s = self.maxpool1(F.relu(self.bn1(self.conv1(s)))) # batch_size x 300 x 333 x 1\n", + " s = self.maxpool2(F.relu(self.bn2(self.conv2(s)))) # batch_size x 200 x 83 x 1\n", + " s = self.maxpool3(F.relu(self.bn3(self.conv3(s)))) # batch_size x 200 x 21 x 1\n", + " s = s.view(-1, 4200)\n", + " conv_out = s\n", + "\n", + " s = F.dropout(F.relu(self.bn4(self.fc1(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", + " #s = F.dropout(F.relu(self.bn5(self.fc2(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", + " \n", + " \n", + " s = self.fc3(s)\n", + "\n", + " return s, conv_out\n", + "\n", + "\n", + "class DanQ(nn.Module):\n", + " def __init__(self, sequence_length, n_genomic_features):\n", + " \"\"\"\n", + " Parameters\n", + " ----------\n", + " sequence_length : int\n", + " Input sequence length\n", + " n_genomic_features : int\n", + " Total number of features to predict\n", + " \"\"\"\n", + " super(DanQ, self).__init__()\n", + " self.nnet = nn.Sequential(\n", + " nn.Conv1d(4, 320, kernel_size=26),\n", + " nn.ReLU(inplace=True),\n", + " nn.MaxPool1d(\n", + " kernel_size=13, stride=13),\n", + " nn.Dropout(0.2))\n", + "\n", + " self.bdlstm = nn.Sequential(\n", + " nn.LSTM(\n", + " 320, 320, num_layers=1, batch_first=True, bidirectional=True))\n", + "\n", + " self._n_channels = math.floor(\n", + " (sequence_length - 25) / 13)\n", + " self.classifier = nn.Sequential(\n", + " nn.Dropout(0.5),\n", + " nn.Linear(self._n_channels * 640, 925),\n", + " nn.ReLU(inplace=True),\n", + " nn.Linear(925, n_genomic_features),\n", + " nn.Sigmoid())\n", + "\n", + " def forward(self, x):\n", + " \"\"\"Forward propagation of a batch.\n", + " \"\"\"\n", + " out = self.nnet(x)\n", + " reshape_out = out.transpose(0, 1).transpose(0, 2)\n", + " out, _ = self.bdlstm(reshape_out)\n", + " out = out.transpose(0, 1)\n", + " reshape_out = out.contiguous().view(\n", + " out.size(0), 640 * self._n_channels)\n", + " predict = self.classifier(reshape_out)\n", + " return predict\n", + "\n", + "\n", + "class DeepSEA(nn.Module):\n", + " def __init__(self, sequence_length, n_genomic_features):\n", + " \"\"\"\n", + " Parameters\n", + " ----------\n", + " sequence_length : int\n", + " n_genomic_features : int\n", + " \"\"\"\n", + " super(DeepSEA, self).__init__()\n", + " conv_kernel_size = 8\n", + " pool_kernel_size = 4\n", + "\n", + " self.conv_net = nn.Sequential(\n", + " nn.Conv1d(4, 320, kernel_size=conv_kernel_size),\n", + " nn.ReLU(inplace=True),\n", + " nn.MaxPool1d(\n", + " kernel_size=pool_kernel_size, stride=pool_kernel_size),\n", + " nn.Dropout(p=0.2),\n", + "\n", + " nn.Conv1d(320, 480, kernel_size=conv_kernel_size),\n", + " nn.ReLU(inplace=True),\n", + " nn.MaxPool1d(\n", + " kernel_size=pool_kernel_size, stride=pool_kernel_size),\n", + " nn.Dropout(p=0.2),\n", + "\n", + " nn.Conv1d(480, 960, kernel_size=conv_kernel_size),\n", + " nn.ReLU(inplace=True),\n", + " nn.Dropout(p=0.5))\n", + "\n", + " reduce_by = conv_kernel_size - 1\n", + " pool_kernel_size = float(pool_kernel_size)\n", + " self.n_channels = int(\n", + " np.floor(\n", + " (np.floor(\n", + " (sequence_length - reduce_by) / pool_kernel_size)\n", + " - reduce_by) / pool_kernel_size)\n", + " - reduce_by)\n", + " self.classifier = nn.Sequential(\n", + " nn.Linear(960 * self.n_channels, n_genomic_features),\n", + " nn.ReLU(inplace=True),\n", + " nn.Linear(n_genomic_features, n_genomic_features),\n", + " nn.Sigmoid())\n", + "\n", + " def forward(self, x):\n", + " \"\"\"Forward propagation of a batch.\n", + " \"\"\"\n", + " out = self.conv_net(x)\n", + " reshape_out = out.view(out.size(0), 960 * self.n_channels)\n", + " predict = self.classifier(reshape_out)\n", + " return predict" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "class Beagle(nn.Module):\n", + " \"\"\"\n", + " Neural net models over genomic sequence.\n", + " Input:\n", + " - sequence_length: int (default 1000) \n", + " - Shape: (N, 5, sequence_length, 1) with batch size N.\n", + " \n", + " Output:\n", + " - prediction (Tensor): float torch tensor of shape (N, )\n", + " \n", + " TODO: Finish docstring.\n", + " \"\"\"\n", + " def __init__(self):\n", + " \"\"\"\n", + " Parameters\n", + " ----------\n", + " sequence_length : int\n", + " n_genomic_features : int\n", + " \"\"\"\n", + " super(Beagle, self).__init__()\n", + "\n", + " self.dropout = 0.3\n", + " self.num_cell_types = 1\n", + " self.conv1 = nn.Conv2d(5, 300, (19, 1), stride = (1, 1), padding=(9,0))\n", + " self.conv2 = nn.Conv2d(300, 200, (11, 1), stride = (1, 1), padding = (5,0))\n", + " self.conv3 = nn.Conv2d(200, 200, (7, 1), stride = (1, 1), padding = (4,0))\n", + " self.bn1 = nn.BatchNorm2d(300)\n", + " self.bn2 = nn.BatchNorm2d(200)\n", + " self.bn3 = nn.BatchNorm2d(200)\n", + " self.maxpool1 = nn.MaxPool2d((3, 1))\n", + " self.maxpool2 = nn.MaxPool2d((4, 1))\n", + " self.maxpool3 = nn.MaxPool2d((4, 1))\n", + "\n", + " self.fc1 = nn.Linear(4200, 1000)\n", + " self.bn4 = nn.BatchNorm1d(1000)\n", + "\n", + " self.fc2 = nn.Linear(1000, 1000)\n", + " self.bn5 = nn.BatchNorm1d(1000)\n", + "\n", + " self.fc3 = nn.Linear(1000, self.num_cell_types)\n", + "\n", + " def forward(self, s):\n", + " s = s.permute(0, 2, 1).contiguous() # batch_size x 5 x 1000\n", + " s = s.view(-1, 5, 1000, 1) # batch_size x 5 x 1000 x 1 [5 channels]\n", + " s = self.maxpool1(F.relu(self.bn1(self.conv1(s)))) # batch_size x 300 x 333 x 1\n", + " s = self.maxpool2(F.relu(self.bn2(self.conv2(s)))) # batch_size x 200 x 83 x 1\n", + " s = self.maxpool3(F.relu(self.bn3(self.conv3(s)))) # batch_size x 200 x 21 x 1\n", + " s = s.view(-1, 4200)\n", + " conv_out = s\n", + "\n", + " s = F.dropout(F.relu(self.bn4(self.fc1(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", + " s = F.dropout(F.relu(self.bn5(self.fc2(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", + " \n", + " s = self.fc3(s)\n", + "\n", + " return s, conv_out" + ] + }, + { + "cell_type": "code", + "execution_count": 86, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'A549': {'chr1': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr10': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr11': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr12': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr13': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr14': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr15': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr16': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr17': array([ 0.35986328, 0.35986328, 0.35986328, ..., 0. ,\n", - " 0. , 0. ], dtype=float16),\n", - " 'chr18': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr19': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr2': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr20': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr21': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr22': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr3': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr4': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr5': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr6': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr7': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr8': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr9': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chrX': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16)},\n", - " 'GM12878': {'chr1': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr10': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr11': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr12': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr13': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr14': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr15': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr16': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr17': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr18': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr19': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr2': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr20': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr21': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr22': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr3': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr4': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr5': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr6': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr7': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr8': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr9': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chrX': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16)},\n", - " 'H1-hESC': {'chr1': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr10': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr11': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr12': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr13': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr14': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr15': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr16': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr17': array([ 0.71972656, 0.71972656, 0.71972656, ..., 0. ,\n", - " 0. , 0. ], dtype=float16),\n", - " 'chr18': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr19': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr2': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr20': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr21': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr22': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr3': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr4': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr5': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr6': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr7': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr8': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr9': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chrX': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16)},\n", - " 'HCT116': {'chr1': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr10': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr11': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr12': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr13': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr14': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr15': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr16': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr17': array([ 0.80419922, 0.80419922, 0.80419922, ..., 0. ,\n", - " 0. , 0. ], dtype=float16),\n", - " 'chr18': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr19': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr2': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr20': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr21': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr22': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr3': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr4': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr5': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr6': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr7': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr8': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr9': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chrX': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16)},\n", - " 'HeLa-S3': {'chr1': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr10': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr11': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr12': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr13': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr14': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr15': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr16': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr17': array([ 0.71972656, 0.71972656, 0.71972656, ..., 0. ,\n", - " 0. , 0. ], dtype=float16),\n", - " 'chr18': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr19': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr2': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr20': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr21': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr22': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr3': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr4': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr5': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr6': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr7': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr8': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr9': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chrX': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16)},\n", - " 'HepG2': {'chr1': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr10': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr11': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr12': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr13': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr14': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr15': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr16': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr17': array([ 0.71972656, 0.71972656, 0.71972656, ..., 0. ,\n", - " 0. , 0. ], dtype=float16),\n", - " 'chr18': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr19': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr2': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr20': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr21': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr22': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr3': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr4': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr5': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr6': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr7': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr8': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr9': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chrX': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16)},\n", - " 'K562': {'chr1': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr10': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr11': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr12': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr13': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr14': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr15': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr16': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr17': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr18': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr19': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr2': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr20': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr21': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr22': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr3': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr4': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr5': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr6': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr7': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr8': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chr9': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16),\n", - " 'chrX': array([ 0., 0., 0., ..., 0., 0., 0.], dtype=float16)}}" + "[('nnet.0.weight', 33280),\n", + " ('nnet.0.bias', 320),\n", + " ('bdlstm.0.weight_ih_l0', 409600),\n", + " ('bdlstm.0.weight_hh_l0', 409600),\n", + " ('bdlstm.0.bias_ih_l0', 1280),\n", + " ('bdlstm.0.bias_hh_l0', 1280),\n", + " ('bdlstm.0.weight_ih_l0_reverse', 409600),\n", + " ('bdlstm.0.weight_hh_l0_reverse', 409600),\n", + " ('bdlstm.0.bias_ih_l0_reverse', 1280),\n", + " ('bdlstm.0.bias_hh_l0_reverse', 1280),\n", + " ('classifier.1.weight', 592000),\n", + " ('classifier.1.bias', 925),\n", + " ('classifier.3.weight', 4625),\n", + " ('classifier.3.bias', 5)]" ] }, - "execution_count": 5, + "execution_count": 86, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "_dnase_allcelltypes" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "from examples.models import CNN_genome" + "def count_parameters(model):\n", + " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", + "\n", + "model = Beagle2()\n", + "model = DanQ(50, 5)\n", + "\n", + "lst = [(x[0], x[1].numel()) for x in model.named_parameters()]\n", + "#np.sum([x[1] for x in lst])\n", + "count_parameters(model)\n", + "lst" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 48, "metadata": {}, "outputs": [ { - "ename": "TypeError", - "evalue": "unbound method parameters() must be called with Beagle instance as first argument (got nothing instead)", + "ename": "AttributeError", + "evalue": "'module' object has no attribute 'isin'", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# def count_parameters(model):\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;31m# return sum(p.numel() for p in model.parameters() if p.requires_grad)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mCNN_genome\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mBeagle\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m: unbound method parameters() must be called with Beagle instance as first argument (got nothing instead)" + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mtr_chrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr2'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr9'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr11'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mte_chrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr1'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr8'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr21'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mtraining_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtr_chrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mval_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mval_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mte_chrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mall_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtraining_df\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_df\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mAttributeError\u001b[0m: 'module' object has no attribute 'isin'" ] } ], - "source": [ - "# def count_parameters(model):\n", - "# return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", - "CNN_genome.Beagle.parameters()" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], "source": [ "tr_chrs = ['chr2', 'chr9', 'chr11']\n", "te_chrs = ['chr1', 'chr8', 'chr21']\n", @@ -337,6 +405,23 @@ "all_df = all_df[filter_msk]" ] }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1.12.1\n" + ] + } + ], + "source": [ + "print(np.__version__)" + ] + }, { "cell_type": "code", "execution_count": 30, @@ -528,17 +613,6 @@ "print(time.time() - itime)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "python3 examples/run_expt.py -d camelyon17 --algorithm ERM --model densenet121 --split_scheme standard --groupby_fields hospital --loss_function cross_entropy --optimizer SGD \n", - "--lr 0.0001 --batch_size 32 --weight_decay 0 --n_epochs 10 --scheduler ReduceLROnPlateau --scheduler_metric_split val --scheduler_metric_name acc_avg \n", - "--log_dir log --log_every 50 --save_step 1000 --save_best --save_last --seed 0 --evaluate_all_splits --root_dir ROOTDIR" - ] - }, { "cell_type": "code", "execution_count": 156, From a86b32c26a61d692cd69ea0c900459036445549a Mon Sep 17 00:00:00 2001 From: aikanor Date: Mon, 8 Feb 2021 08:51:43 -0800 Subject: [PATCH 054/244] final integration 1/ --- .../encode-tfbs/prep_accessibility.py | 2 ++ .../encode-tfbs/prep_sequence.py | 2 ++ examples/models/CNN_genome.py | 2 +- sandbox_data.ipynb | 24 +++++++++++++++++++ 4 files changed, 29 insertions(+), 1 deletion(-) diff --git a/dataset_preprocessing/encode-tfbs/prep_accessibility.py b/dataset_preprocessing/encode-tfbs/prep_accessibility.py index 9033224e..7342f797 100644 --- a/dataset_preprocessing/encode-tfbs/prep_accessibility.py +++ b/dataset_preprocessing/encode-tfbs/prep_accessibility.py @@ -3,6 +3,8 @@ from tqdm import tqdm +# Human chromosome names +chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] def generate_accessibility_archives(input_dir, output_dir): dnases = {} diff --git a/dataset_preprocessing/encode-tfbs/prep_sequence.py b/dataset_preprocessing/encode-tfbs/prep_sequence.py index 5a0baea5..7f396d9f 100644 --- a/dataset_preprocessing/encode-tfbs/prep_sequence.py +++ b/dataset_preprocessing/encode-tfbs/prep_sequence.py @@ -3,6 +3,8 @@ from tqdm import tqdm +# Human chromosome names +chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] def one_hot_encode(sequence, ignore='N', alphabet=None, dtype='int8', verbose=False, **kwargs): diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py index 75295cd3..8a658eab 100644 --- a/examples/models/CNN_genome.py +++ b/examples/models/CNN_genome.py @@ -59,4 +59,4 @@ def forward(self, s): s = self.fc3(s) - return s, conv_out + return s#, conv_out diff --git a/sandbox_data.ipynb b/sandbox_data.ipynb index 55a67da4..ad5ae4bd 100644 --- a/sandbox_data.ipynb +++ b/sandbox_data.ipynb @@ -1,5 +1,29 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "- examples\n", + " - run_expt.py\n", + " - configs\n", + " - [x] supported.py\n", + " - [ ] model.py\n", + " - [ ] datasets.py\n", + " - models\n", + " - [x] CNN_genome.py\n", + " - train.py\n", + " - utils.py\n", + "- wilds\n", + " - [x] datasets/encodetfbs_dataset.py\n", + " - common\n", + " - metrics\n", + " - [ ] all_metrics.py\n", + " - data_loaders.py\n", + " - grouper.py\n", + " - [ ] utils.py ( threshold_at_recall() )" + ] + }, { "cell_type": "markdown", "metadata": {}, From d78f2a3da696192b18a25bfa20bc6bfde717c22b Mon Sep 17 00:00:00 2001 From: aikanor Date: Mon, 8 Feb 2021 16:45:18 -0800 Subject: [PATCH 055/244] integration 2/ --- examples/configs/model.py | 3 +- sandbox_data.ipynb | 387 +-------------------------- wilds/datasets/encodetfbs_dataset.py | 5 +- 3 files changed, 11 insertions(+), 384 deletions(-) diff --git a/examples/configs/model.py b/examples/configs/model.py index 46714bbe..af37ab8e 100644 --- a/examples/configs/model.py +++ b/examples/configs/model.py @@ -36,5 +36,6 @@ 'resnet18_ms': { 'target_resolution': (224, 224), }, - 'logistic_regression': {}, + 'logistic_regression': {}, + 'beagle': {}, } diff --git a/sandbox_data.ipynb b/sandbox_data.ipynb index ad5ae4bd..0a9806a6 100644 --- a/sandbox_data.ipynb +++ b/sandbox_data.ipynb @@ -24,6 +24,13 @@ " - [ ] utils.py ( threshold_at_recall() )" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "markdown", "metadata": {}, @@ -579,386 +586,6 @@ } ], "source": [] - }, - { - "cell_type": "code", - "execution_count": 165, - "metadata": {}, - "outputs": [ - { - "ename": "ModuleNotFoundError", - "evalue": "No module named 'torch_scatter'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minsert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'..'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata_loaders\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_train_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mget_eval_loader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 17\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrouper\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCombinatorialGrouper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_counts\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/dr_benchmark/wilds/common/data_loaders.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msampler\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mWeightedRandomSampler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSubsetRandomSampler\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_counts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msplit_into_groups\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/dr_benchmark/wilds/common/utils.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch_scatter\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mSubset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mpandas\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtypes\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCategoricalDtype\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torch_scatter'" - ] - } - ], - "source": [ - "import os, csv\n", - "import time\n", - "import argparse\n", - "import IPython\n", - "import pandas as pd\n", - "import torch\n", - "import torch.nn as nn\n", - "import torchvision\n", - "import sys\n", - "from collections import defaultdict\n", - "# torch.multiprocessing.set_sharing_strategy('file_system')\n", - "\n", - "# TODO: Replace this once we make wilds into an installed package\n", - "sys.path.insert(1, os.path.join(sys.path[0], '..'))\n", - "\n", - "from wilds.common.data_loaders import get_train_loader, get_eval_loader\n", - "from wilds.common.grouper import CombinatorialGrouper\n", - "from wilds.common.utils import get_counts\n", - "\n", - "from models.model_attributes import model_attributes\n", - "from utils import set_seed, Logger, BatchLogger, log_args, ParseKwargs, load\n", - "from train import train, evaluate\n", - "from data import dataset_attributes\n", - "from optimizer import optimizer_attributes\n", - "from scheduler import scheduler_attributes\n", - "from loss import losses\n", - "from utils import log_group_data\n", - "from algorithms.constructors import algorithm_constructors" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "from examples.models.model_attributes import model_attributes" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "ename": "ModuleNotFoundError", - "evalue": "No module named 'utils'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_attributes\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmodel_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mset_seed\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mLogger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mCSVBatchLogger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mParseKwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mload\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 22\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdataset_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0moptimizer_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/dr_benchmark/examples/train.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0msave\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'utils'" - ] - } - ], - "source": [ - "def initialize_algorithm(args, datasets, train_grouper):\n", - " train_dataset = datasets['train']['dataset']\n", - " train_loader = datasets['train']['loader']\n", - "\n", - " # Configure the final layer of the networks used\n", - " # The code below are defaults. Edit this if you need special config for your model.\n", - " if (train_dataset.is_classification) and (train_dataset.y_size == 1):\n", - " # For single-task classification, we have one output per class\n", - " d_out = train_dataset.n_classes\n", - " elif (train_dataset.is_classification) and (train_dataset.y_size > 1) and (train_dataset.n_classes == 2):\n", - " # For multi-task binary classification (each output is the logit for each binary class)\n", - " d_out = train_dataset.y_size\n", - " elif (not train_dataset.is_classification):\n", - " # For regression, we have one output per target dimension\n", - " d_out = train_dataset.y_size\n", - " else:\n", - " raise RuntimeError('d_out not defined.')\n", - " \n", - "\n", - " # Sanity checking input args\n", - " if args.algorithm == 'groupDRO':\n", - " assert args.train_loader_kwargs['uniform_over_groups']\n", - " elif args.algorithm in ['deepCORAL', 'IRM']:\n", - " assert args.train_loader == 'group'\n", - " assert args.train_loader_kwargs['uniform_over_groups']\n", - " assert args.train_loader_kwargs['distinct_groups']\n", - "\n", - " # Other config\n", - " n_train_steps = len(train_loader) * args.n_epochs\n", - "# prediction_fn = dataset_attributes[args.dataset]['prediction_fn']\n", - " loss = losses[args.loss_function]\n", - " metric = dataset_attributes[args.dataset]['metric']\n", - " train_g = train_grouper.metadata_to_group(train_dataset.metadata_array)\n", - " is_group_in_train = get_counts(train_g, train_grouper.n_groups) > 0\n", - " algorithm_constructor = algorithm_constructors[args.algorithm]\n", - " algorithm = algorithm_constructor(\n", - " args=args,\n", - " d_out=d_out,\n", - " grouper=train_grouper,\n", - " loss=loss,\n", - " metric=metric,\n", - " n_train_steps=n_train_steps,\n", - " is_group_in_train=is_group_in_train)\n", - " return algorithm" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def main():\n", - " parser = argparse.ArgumentParser()\n", - "\n", - " # Dataset\n", - " parser.add_argument('-d', '--dataset', choices=dataset_attributes.keys(), required=True)\n", - " parser.add_argument('--split_scheme', default='standard',\n", - " help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')\n", - " parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={})\n", - " parser.add_argument('--root_dir', default=None, required=True,\n", - " help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')\n", - " parser.add_argument('--download', default=False, action='store_true',\n", - " help='If true, tries to downloads the dataset if it does not exist in root_dir.')\n", - " parser.add_argument('--frac', type=float, default=1.0,\n", - " help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.')\n", - "\n", - " # Loaders\n", - " parser.add_argument('--train_loader', choices=['standard', 'group'], default='standard')\n", - " parser.add_argument('--train_loader_kwargs', nargs='*', action=ParseKwargs, default={})\n", - " parser.add_argument('--eval_loader', choices=['standard', 'group'], default='standard')\n", - " parser.add_argument('--batch_size', type=int, default=32)\n", - " parser.add_argument('--no_pin_memory', action='store_true') # TODO: put as loader_kwargs\n", - " parser.add_argument('--num_workers', type=int, default=4) # TODO: put as loader kwargs\n", - "\n", - " # Model\n", - " parser.add_argument(\n", - " '--model',\n", - " choices=model_attributes.keys(),\n", - " default='resnet50')\n", - " parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},\n", - " help='keyword arguments for model initialization passed as key1=value1 key2=value2')\n", - " parser.add_argument('--train_from_scratch', action='store_true', default=False)\n", - "\n", - " # Algorithm and objective\n", - " parser.add_argument('--algorithm', required=True, choices=algorithm_constructors.keys())\n", - " parser.add_argument('--algorithm_kwargs', nargs='*', action=ParseKwargs, default={})\n", - " parser.add_argument('--groupby_fields', nargs='+', default=None)\n", - " parser.add_argument('--loss_function', required=True, choices = losses.keys()) #TODO: make default\n", - " parser.add_argument('--val_metric', default=None)\n", - "\n", - " # Optimization\n", - " parser.add_argument('--n_epochs', type=int, default=4)\n", - " parser.add_argument('--optimizer', default=None, choices=optimizer_attributes.keys())\n", - " parser.add_argument('--lr', type=float, required=True)\n", - " parser.add_argument('--weight_decay', type=float, required=True)\n", - " parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})\n", - " parser.add_argument('--scheduler', default=None, choices=scheduler_attributes.keys())\n", - " parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})\n", - " parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')\n", - " parser.add_argument('--scheduler_metric_name')\n", - "\n", - " # Evaluation\n", - " parser.add_argument('--evaluate_all_splits', action='store_true', default=False)\n", - " parser.add_argument('--additional_eval_splits', nargs='+', default=[])\n", - "\n", - " # Misc\n", - " parser.add_argument('--device', type=int, default=0)\n", - " parser.add_argument('--seed', type=int, default=0)\n", - " parser.add_argument('--log_dir', default='./logs')\n", - " parser.add_argument('--log_every', default=50, type=int)\n", - " parser.add_argument('--save_step', type=int, default=None)\n", - " parser.add_argument('--save_best', action='store_true', default=False)\n", - " parser.add_argument('--save_last', action='store_true', default=False)\n", - " parser.add_argument('--save_outputs', action='store_true', default=False)\n", - " parser.add_argument('--no_group_logging', action='store_true', default=False)\n", - " parser.add_argument('--val_metric_decreasing', action='store_true', default=False)\n", - " parser.add_argument('--use_wandb', action='store_true', default=False)\n", - " parser.add_argument('--progress_bar', action='store_true', default=False)\n", - " parser.add_argument('--resume', default=False, action='store_true')\n", - " parser.add_argument('--eval_only', default=False, action='store_true')\n", - "\n", - " args = parser.parse_args()\n", - "\n", - " # set device\n", - " args.device = torch.device(\"cuda:\" + str(args.device)) if torch.cuda.is_available() else torch.device(\"cpu\")\n", - "\n", - " # Set defaults\n", - " if args.groupby_fields is None:\n", - " args.no_group_logging = True\n", - " if args.val_metric is None:\n", - " args.val_metric = dataset_attributes[args.dataset]['val_metric']\n", - "\n", - " ## Initialize logs\n", - " if os.path.exists(args.log_dir) and args.resume:\n", - " resume=True\n", - " mode='a'\n", - " else:\n", - " resume=False\n", - " mode='w'\n", - " if not os.path.exists(args.log_dir):\n", - " os.makedirs(args.log_dir)\n", - " logger = Logger(os.path.join(args.log_dir, 'log.txt'), mode)\n", - "\n", - " # Record args\n", - " log_args(args, logger)\n", - "\n", - " # Set random seed\n", - " set_seed(args.seed)\n", - "\n", - " # Data\n", - " full_dataset = dataset_attributes[args.dataset]['constructor'](\n", - " root_dir=args.root_dir,\n", - " download=args.download,\n", - " split_scheme=args.split_scheme,\n", - " **args.dataset_kwargs)\n", - "\n", - " # To implement data augmentation (i.e., have different transforms\n", - " # at training time vs. test time), modify these two lines:\n", - " train_transform = dataset_attributes[args.dataset]['transform'](args.model)\n", - " if dataset_attributes[args.dataset].get('eval_transform') is None:\n", - " eval_transform = dataset_attributes[args.dataset]['transform'](args.model)\n", - " else:\n", - " eval_transform = dataset_attributes[args.dataset]['eval_transform'](args.model)\n", - "\n", - " train_grouper = CombinatorialGrouper(\n", - " dataset=full_dataset,\n", - " groupby_fields=args.groupby_fields)\n", - "\n", - " datasets = defaultdict(dict)\n", - " for split in full_dataset.split_dict.keys():\n", - " if split=='train':\n", - " transform = train_transform\n", - " verbose = True\n", - " elif split == 'val':\n", - " transform = eval_transform\n", - " verbose = True\n", - " else:\n", - " transform = eval_transform\n", - " verbose = False\n", - " # Get subset\n", - " datasets[split]['dataset'] = full_dataset.get_subset(\n", - " split,\n", - " frac=args.frac,\n", - " transform=transform)\n", - "\n", - " # Get loader\n", - " shared_loader_kwargs = {\n", - " 'num_workers': args.num_workers,\n", - " 'pin_memory': not args.no_pin_memory,\n", - " 'batch_size': args.batch_size,\n", - " 'collate_fn': dataset_attributes[args.dataset]['collate']\n", - " }\n", - "\n", - " if split == 'train':\n", - " datasets[split]['loader'] = get_train_loader(\n", - " loader=args.train_loader,\n", - " dataset=datasets[split]['dataset'],\n", - " grouper=train_grouper,\n", - " train_loader_kwargs=args.train_loader_kwargs,\n", - " **shared_loader_kwargs)\n", - " else:\n", - " datasets[split]['loader'] = get_eval_loader(\n", - " loader=args.eval_loader,\n", - " dataset=datasets[split]['dataset'],\n", - " grouper=train_grouper,\n", - " **shared_loader_kwargs)\n", - "\n", - " # Set fields\n", - " datasets[split]['split'] = split\n", - " datasets[split]['name'] = full_dataset.split_names[split]\n", - " datasets[split]['verbose'] = verbose\n", - " # Loggers\n", - " # Loggers\n", - " datasets[split]['eval_logger'] = BatchLogger(\n", - " os.path.join(args.log_dir, f'{split}_eval.csv'), mode=mode, use_wandb=args.use_wandb)\n", - " datasets[split]['algo_logger'] = BatchLogger(\n", - " os.path.join(args.log_dir, f'{split}_algo.csv'), mode=mode, use_wandb=args.use_wandb)\n", - "\n", - " if args.use_wandb:\n", - " initialize_wandb(args)\n", - "\n", - " # Logging dataset info\n", - " if args.no_group_logging and full_dataset.is_classification and full_dataset.y_size==1:\n", - " log_grouper = CombinatorialGrouper(\n", - " dataset=full_dataset,\n", - " groupby_fields=['y'])\n", - " elif args.no_group_logging:\n", - " log_grouper = None\n", - " else:\n", - " log_grouper = train_grouper\n", - " log_group_data(args, datasets, log_grouper, logger)\n", - "\n", - " ## Initialize algorithm\n", - " algorithm = initialize_algorithm(args, datasets, train_grouper)\n", - "\n", - " if not args.eval_only:\n", - " ## Load saved results if resuming\n", - " resume_success = False\n", - " if resume:\n", - " save_path = os.path.join(args.log_dir, 'last_model.pth')\n", - " if not os.path.exists(save_path):\n", - " epochs = [\n", - " int(file.split('_')[0])\n", - " for file in os.listdir(args.log_dir) if file.endswith('.pth')]\n", - " if len(epochs) > 0:\n", - " latest_epoch = max(epochs)\n", - " save_path = os.path.join(args.log_dir, f'{latest_epoch}_model.pth')\n", - " try:\n", - " prev_epoch, best_val_metric = load(algorithm, save_path)\n", - " epoch_offset = prev_epoch + 1\n", - " logger.write(f'Resuming from epoch {epoch_offset} with best val metric {best_val_metric}')\n", - " resume_success = True\n", - " except FileNotFoundError:\n", - " pass\n", - "\n", - " if resume_success == False:\n", - " epoch_offset=0\n", - " best_val_metric=None\n", - "\n", - "\n", - " train(algorithm,\n", - " datasets,\n", - " logger,\n", - " args,\n", - " epoch_offset=epoch_offset,\n", - " best_val_metric=best_val_metric)\n", - " else:\n", - " best_model_path = os.path.join(args.log_dir, 'best_model.pth')\n", - " best_epoch, best_val_metric = load(algorithm, best_model_path)\n", - " evaluate(algorithm, datasets, best_epoch, logger)\n", - "\n", - " logger.close()\n", - " for split in datasets:\n", - " datasets[split]['eval_logger'].close()\n", - " datasets[split]['algo_logger'].close()\n", - "\n", - "if __name__=='__main__':\n", - " main()\n" - ] } ], "metadata": { diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 08276aa9..6996cc15 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -5,7 +5,6 @@ from wilds.datasets.wilds_dataset import WILDSDataset from wilds.common.grouper import CombinatorialGrouper from wilds.common.metrics.eval_metric import Accuracy -from wilds.common.eval import standard_group_eval import IPython @@ -133,7 +132,7 @@ def get_input(self, idx): (3) Metadata for the index (location along the genome with 1kb window width) """ this_metadata = self._metadata_df.iloc[idx, :] - flank_size = 500 + flank_size = 400 interval_start = this_metadata['start'] - flank_size interval_end = this_metadata['stop'] + flank_size dnase_this = _dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end] @@ -141,7 +140,7 @@ def get_input(self, idx): return np.column_stack([seq_this, dnase_this]) def eval(self, y_pred, y_true, metadata): - return standard_group_eval( + return self.standard_group_eval( self._metric, self._eval_grouper, y_pred, y_true, metadata) From b3b6f626cadf7d85f4a06974245c0b4c7e49c4e7 Mon Sep 17 00:00:00 2001 From: aikanor Date: Mon, 8 Feb 2021 17:40:41 -0800 Subject: [PATCH 056/244] integration 3/ --- sbox_run_expt.ipynb | 1032 ++++++++++++++++++++++++++ wilds/datasets/encodetfbs_dataset.py | 4 +- 2 files changed, 1033 insertions(+), 3 deletions(-) create mode 100644 sbox_run_expt.ipynb diff --git a/sbox_run_expt.ipynb b/sbox_run_expt.ipynb new file mode 100644 index 00000000..612397ce --- /dev/null +++ b/sbox_run_expt.ipynb @@ -0,0 +1,1032 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# run_expt.py contents" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "ename": "SyntaxError", + "evalue": "invalid syntax (version.py, line 20)", + "output_type": "error", + "traceback": [ + "\u001b[0;36m File \u001b[0;32m\"wilds/version.py\"\u001b[0;36m, line \u001b[0;32m20\u001b[0m\n\u001b[0;31m f'The WILDS package is out of date. Your version is {__version__}, while the latest version is {latest}.')\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax\n" + ] + } + ], + "source": [ + "import os, csv\n", + "import time\n", + "import argparse\n", + "import pandas as pd\n", + "import torch\n", + "import torch.nn as nn\n", + "import torchvision\n", + "import sys\n", + "from collections import defaultdict\n", + "\n", + "from wilds.common.data_loaders import get_train_loader, get_eval_loader\n", + "from wilds.common.grouper import CombinatorialGrouper\n", + "\n", + "from utils import set_seed, Logger, BatchLogger, log_config, ParseKwargs, load, initialize_wandb, log_group_data, parse_bool\n", + "from train import train, evaluate\n", + "from algorithms.initializer import initialize_algorithm\n", + "from transforms import initialize_transform\n", + "from configs.utils import populate_defaults\n", + "import configs.supported as supported" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Initialize dataset object" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "57.8772239685\n", + "66.8270189762\n" + ] + } + ], + "source": [ + "import numpy as np, pandas as pd, os, time, torch, torchvision\n", + "data_dir = '/oak/stanford/groups/akundaje/abalsubr/DREAM/wilds/codalab_archive/'\n", + "tf = 'MAX'\n", + "itime = time.time()\n", + "train_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.train.labels.tsv.gz'.format(tf)), sep='\\t')\n", + "print(time.time() - itime)\n", + "val_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.val.labels.tsv.gz'.format(tf)), sep='\\t')\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", + "val_celltype = ['A549']\n", + "test_celltype = ['GM12878']\n", + "all_celltypes = train_celltypes + val_celltype + test_celltype\n", + "\n", + "metadata_map = {}\n", + "metadata_map['chr'] = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']\n", + "metadata_map['celltype'] = all_celltypes\n", + "\n", + "_split_dict = {\n", + " 'train': 0,\n", + " 'val-id': 1,\n", + " 'test': 2,\n", + " 'val-ood': 3\n", + "}\n", + "_split_names = {\n", + " 'train': 'Train',\n", + " 'val-id': 'Validation (ID)',\n", + " 'test': 'Test',\n", + " 'val-ood': 'Validation (OOD)'\n", + "}\n", + "_split_scheme = 'standard'" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "('H1-hESC', 25.299736976623535)\n", + "('HCT116', 49.68733310699463)\n", + "('HeLa-S3', 74.65905213356018)\n", + "('HepG2', 99.33112812042236)\n", + "('K562', 124.1327919960022)\n", + "('A549', 149.19999814033508)\n", + "('GM12878', 174.0277030467987)\n" + ] + } + ], + "source": [ + "itime = time.time()\n", + "sequence_filename = os.path.join(data_dir, 'sequence.npz')\n", + "seq_arr = np.load(sequence_filename)\n", + "print(time.time() - itime)\n", + "\n", + "itime = time.time()\n", + "_seq_bp = {}\n", + "for chrom in seq_arr:\n", + " _seq_bp[chrom] = seq_arr[chrom]\n", + " print(chrom, time.time() - itime)\n", + "itime = time.time()\n", + "_dnase_allcelltypes = {}\n", + "for ct in all_celltypes:\n", + " dnase_filename = os.path.join(data_dir, '{}_dnase.npz'.format(ct))\n", + " dnase_npz_file = np.load(dnase_filename)\n", + " _dnase_allcelltypes[ct] = {}\n", + " for chrom in _seq_bp:\n", + " _dnase_allcelltypes[ct][chrom] = dnase_npz_file[chrom]\n", + " print(ct, time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class Beagle2(nn.Module):\n", + " \"\"\"\n", + " Neural net models over genomic sequence.\n", + " Input:\n", + " - sequence_length: int (default 1000) \n", + " - Shape: (N, 5, sequence_length, 1) with batch size N.\n", + " \n", + " Output:\n", + " - prediction (Tensor): float torch tensor of shape (N, )\n", + " \n", + " TODO: Finish docstring.\n", + " \"\"\"\n", + " def __init__(self):\n", + " \"\"\"\n", + " Parameters\n", + " ----------\n", + " sequence_length : int\n", + " n_genomic_features : int\n", + " \"\"\"\n", + " super(Beagle2, self).__init__()\n", + "\n", + " self.dropout = 0.3\n", + " self.num_cell_types = 1\n", + " self.conv1 = nn.Conv2d(5, 300, (19, 1), stride = (1, 1), padding=(9,0))\n", + " self.conv2 = nn.Conv2d(300, 200, (11, 1), stride = (1, 1), padding = (5,0))\n", + " self.conv3 = nn.Conv2d(200, 200, (7, 1), stride = (1, 1), padding = (4,0))\n", + " self.bn1 = nn.BatchNorm2d(300)\n", + " self.bn2 = nn.BatchNorm2d(200)\n", + " self.bn3 = nn.BatchNorm2d(200)\n", + " self.maxpool1 = nn.MaxPool2d((3, 1))\n", + " self.maxpool2 = nn.MaxPool2d((4, 1))\n", + " self.maxpool3 = nn.MaxPool2d((4, 1))\n", + "\n", + " self.fc1 = nn.Linear(4200, 1000)\n", + " self.bn4 = nn.BatchNorm1d(1000)\n", + "\n", + " self.fc2 = nn.Linear(1000, 1000)\n", + " self.bn5 = nn.BatchNorm1d(1000)\n", + "\n", + " self.fc3 = nn.Linear(1000, self.num_cell_types)\n", + "\n", + " def forward(self, s):\n", + " s = s.permute(0, 2, 1).contiguous() # batch_size x 4 x 1000\n", + " s = s.view(-1, 5, 1000, 1) # batch_size x 4 x 1000 x 1 [4 channels]\n", + " s = self.maxpool1(F.relu(self.bn1(self.conv1(s)))) # batch_size x 300 x 333 x 1\n", + " s = self.maxpool2(F.relu(self.bn2(self.conv2(s)))) # batch_size x 200 x 83 x 1\n", + " s = self.maxpool3(F.relu(self.bn3(self.conv3(s)))) # batch_size x 200 x 21 x 1\n", + " s = s.view(-1, 4200)\n", + " conv_out = s\n", + "\n", + " s = F.dropout(F.relu(self.bn4(self.fc1(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", + " #s = F.dropout(F.relu(self.bn5(self.fc2(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", + " \n", + " \n", + " s = self.fc3(s)\n", + "\n", + " return s, conv_out\n", + "\n", + "\n", + "class DanQ(nn.Module):\n", + " def __init__(self, sequence_length, n_genomic_features):\n", + " \"\"\"\n", + " Parameters\n", + " ----------\n", + " sequence_length : int\n", + " Input sequence length\n", + " n_genomic_features : int\n", + " Total number of features to predict\n", + " \"\"\"\n", + " super(DanQ, self).__init__()\n", + " self.nnet = nn.Sequential(\n", + " nn.Conv1d(4, 320, kernel_size=26),\n", + " nn.ReLU(inplace=True),\n", + " nn.MaxPool1d(\n", + " kernel_size=13, stride=13),\n", + " nn.Dropout(0.2))\n", + "\n", + " self.bdlstm = nn.Sequential(\n", + " nn.LSTM(\n", + " 320, 320, num_layers=1, batch_first=True, bidirectional=True))\n", + "\n", + " self._n_channels = math.floor(\n", + " (sequence_length - 25) / 13)\n", + " self.classifier = nn.Sequential(\n", + " nn.Dropout(0.5),\n", + " nn.Linear(self._n_channels * 640, 925),\n", + " nn.ReLU(inplace=True),\n", + " nn.Linear(925, n_genomic_features),\n", + " nn.Sigmoid())\n", + "\n", + " def forward(self, x):\n", + " \"\"\"Forward propagation of a batch.\n", + " \"\"\"\n", + " out = self.nnet(x)\n", + " reshape_out = out.transpose(0, 1).transpose(0, 2)\n", + " out, _ = self.bdlstm(reshape_out)\n", + " out = out.transpose(0, 1)\n", + " reshape_out = out.contiguous().view(\n", + " out.size(0), 640 * self._n_channels)\n", + " predict = self.classifier(reshape_out)\n", + " return predict\n", + "\n", + "\n", + "class DeepSEA(nn.Module):\n", + " def __init__(self, sequence_length, n_genomic_features):\n", + " \"\"\"\n", + " Parameters\n", + " ----------\n", + " sequence_length : int\n", + " n_genomic_features : int\n", + " \"\"\"\n", + " super(DeepSEA, self).__init__()\n", + " conv_kernel_size = 8\n", + " pool_kernel_size = 4\n", + "\n", + " self.conv_net = nn.Sequential(\n", + " nn.Conv1d(4, 320, kernel_size=conv_kernel_size),\n", + " nn.ReLU(inplace=True),\n", + " nn.MaxPool1d(\n", + " kernel_size=pool_kernel_size, stride=pool_kernel_size),\n", + " nn.Dropout(p=0.2),\n", + "\n", + " nn.Conv1d(320, 480, kernel_size=conv_kernel_size),\n", + " nn.ReLU(inplace=True),\n", + " nn.MaxPool1d(\n", + " kernel_size=pool_kernel_size, stride=pool_kernel_size),\n", + " nn.Dropout(p=0.2),\n", + "\n", + " nn.Conv1d(480, 960, kernel_size=conv_kernel_size),\n", + " nn.ReLU(inplace=True),\n", + " nn.Dropout(p=0.5))\n", + "\n", + " reduce_by = conv_kernel_size - 1\n", + " pool_kernel_size = float(pool_kernel_size)\n", + " self.n_channels = int(\n", + " np.floor(\n", + " (np.floor(\n", + " (sequence_length - reduce_by) / pool_kernel_size)\n", + " - reduce_by) / pool_kernel_size)\n", + " - reduce_by)\n", + " self.classifier = nn.Sequential(\n", + " nn.Linear(960 * self.n_channels, n_genomic_features),\n", + " nn.ReLU(inplace=True),\n", + " nn.Linear(n_genomic_features, n_genomic_features),\n", + " nn.Sigmoid())\n", + "\n", + " def forward(self, x):\n", + " \"\"\"Forward propagation of a batch.\n", + " \"\"\"\n", + " out = self.conv_net(x)\n", + " reshape_out = out.view(out.size(0), 960 * self.n_channels)\n", + " predict = self.classifier(reshape_out)\n", + " return predict" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "class Beagle(nn.Module):\n", + " \"\"\"\n", + " Neural net models over genomic sequence.\n", + " Input:\n", + " - sequence_length: int (default 1000) \n", + " - Shape: (N, 5, sequence_length, 1) with batch size N.\n", + " \n", + " Output:\n", + " - prediction (Tensor): float torch tensor of shape (N, )\n", + " \n", + " TODO: Finish docstring.\n", + " \"\"\"\n", + " def __init__(self):\n", + " \"\"\"\n", + " Parameters\n", + " ----------\n", + " sequence_length : int\n", + " n_genomic_features : int\n", + " \"\"\"\n", + " super(Beagle, self).__init__()\n", + "\n", + " self.dropout = 0.3\n", + " self.num_cell_types = 1\n", + " self.conv1 = nn.Conv2d(5, 300, (19, 1), stride = (1, 1), padding=(9,0))\n", + " self.conv2 = nn.Conv2d(300, 200, (11, 1), stride = (1, 1), padding = (5,0))\n", + " self.conv3 = nn.Conv2d(200, 200, (7, 1), stride = (1, 1), padding = (4,0))\n", + " self.bn1 = nn.BatchNorm2d(300)\n", + " self.bn2 = nn.BatchNorm2d(200)\n", + " self.bn3 = nn.BatchNorm2d(200)\n", + " self.maxpool1 = nn.MaxPool2d((3, 1))\n", + " self.maxpool2 = nn.MaxPool2d((4, 1))\n", + " self.maxpool3 = nn.MaxPool2d((4, 1))\n", + "\n", + " self.fc1 = nn.Linear(4200, 1000)\n", + " self.bn4 = nn.BatchNorm1d(1000)\n", + "\n", + " self.fc2 = nn.Linear(1000, 1000)\n", + " self.bn5 = nn.BatchNorm1d(1000)\n", + "\n", + " self.fc3 = nn.Linear(1000, self.num_cell_types)\n", + "\n", + " def forward(self, s):\n", + " s = s.permute(0, 2, 1).contiguous() # batch_size x 5 x 1000\n", + " s = s.view(-1, 5, 1000, 1) # batch_size x 5 x 1000 x 1 [5 channels]\n", + " s = self.maxpool1(F.relu(self.bn1(self.conv1(s)))) # batch_size x 300 x 333 x 1\n", + " s = self.maxpool2(F.relu(self.bn2(self.conv2(s)))) # batch_size x 200 x 83 x 1\n", + " s = self.maxpool3(F.relu(self.bn3(self.conv3(s)))) # batch_size x 200 x 21 x 1\n", + " s = s.view(-1, 4200)\n", + " conv_out = s\n", + "\n", + " s = F.dropout(F.relu(self.bn4(self.fc1(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", + " s = F.dropout(F.relu(self.bn5(self.fc2(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", + " \n", + " s = self.fc3(s)\n", + "\n", + " return s, conv_out" + ] + }, + { + "cell_type": "code", + "execution_count": 86, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[('nnet.0.weight', 33280),\n", + " ('nnet.0.bias', 320),\n", + " ('bdlstm.0.weight_ih_l0', 409600),\n", + " ('bdlstm.0.weight_hh_l0', 409600),\n", + " ('bdlstm.0.bias_ih_l0', 1280),\n", + " ('bdlstm.0.bias_hh_l0', 1280),\n", + " ('bdlstm.0.weight_ih_l0_reverse', 409600),\n", + " ('bdlstm.0.weight_hh_l0_reverse', 409600),\n", + " ('bdlstm.0.bias_ih_l0_reverse', 1280),\n", + " ('bdlstm.0.bias_hh_l0_reverse', 1280),\n", + " ('classifier.1.weight', 592000),\n", + " ('classifier.1.bias', 925),\n", + " ('classifier.3.weight', 4625),\n", + " ('classifier.3.bias', 5)]" + ] + }, + "execution_count": 86, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def count_parameters(model):\n", + " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", + "\n", + "model = Beagle2()\n", + "model = DanQ(50, 5)\n", + "\n", + "lst = [(x[0], x[1].numel()) for x in model.named_parameters()]\n", + "#np.sum([x[1] for x in lst])\n", + "count_parameters(model)\n", + "lst" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [ + { + "ename": "AttributeError", + "evalue": "'module' object has no attribute 'isin'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mtr_chrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr2'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr9'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr11'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mte_chrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr1'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr8'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr21'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mtraining_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtr_chrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mval_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mval_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mte_chrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mall_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtraining_df\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_df\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mAttributeError\u001b[0m: 'module' object has no attribute 'isin'" + ] + } + ], + "source": [ + "tr_chrs = ['chr2', 'chr9', 'chr11']\n", + "te_chrs = ['chr1', 'chr8', 'chr21']\n", + "training_df = train_chr[np.isin(train_chr['chr'], tr_chrs)]\n", + "val_df = val_chr[np.isin(val_chr['chr'], te_chrs)]\n", + "all_df = pd.concat([training_df, val_df])\n", + "\n", + "#filter_msk = all_df['start'] >= 0\n", + "filter_msk = all_df['start']%1000 == 0\n", + "all_df = all_df[filter_msk]" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1.12.1\n" + ] + } + ], + "source": [ + "print(np.__version__)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/users/abalsubr/anaconda2/envs/scs3/lib/python3.6/site-packages/ipykernel_launcher.py:6: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy\n", + " \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1.659163236618042\n" + ] + } + ], + "source": [ + "itime = time.time()\n", + "pd_list = []\n", + "for ct in all_celltypes:\n", + " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", + " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", + " tc_chr['celltype'] = ct\n", + " pd_list.append(tc_chr)\n", + "metadata_df = pd.concat(pd_list)\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3.0391879081726074\n" + ] + } + ], + "source": [ + "itime = time.time()\n", + "y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", + "non_ambig_mask = (y_array != -1)\n", + "metadata_df['y'] = y_array\n", + "_metadata_df = metadata_df[non_ambig_mask]\n", + "_y_array = torch.LongTensor(y_array[non_ambig_mask])\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "12.390011310577393\n" + ] + } + ], + "source": [ + "itime = time.time()\n", + "chr_ints = _metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['chr'])] )).values\n", + "celltype_ints = _metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['celltype'])] )).values\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/users/abalsubr/anaconda2/envs/scs3/lib/python3.6/site-packages/ipykernel_launcher.py:12: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy\n", + " if sys.path[0] == '':\n" + ] + } + ], + "source": [ + "train_chr_mask = np.isin(_metadata_df['chr'], tr_chrs)\n", + "val_chr_mask = np.isin(_metadata_df['chr'], te_chrs)\n", + "train_celltype_mask = np.isin(_metadata_df['celltype'], train_celltypes)\n", + "val_celltype_mask = np.isin(_metadata_df['celltype'], val_celltype)\n", + "test_celltype_mask = np.isin(_metadata_df['celltype'], test_celltype)\n", + "\n", + "split_array = -1*np.ones(_metadata_df.shape[0]).astype(int)\n", + "split_array[np.logical_and(train_chr_mask, train_celltype_mask)] = _split_dict['train']\n", + "split_array[np.logical_and(val_chr_mask, test_celltype_mask)] = _split_dict['test']\n", + "split_array[np.logical_and(val_chr_mask, val_celltype_mask)] = _split_dict['val-ood']\n", + "split_array[np.logical_and(val_chr_mask, train_celltype_mask)] = _split_dict['val-id']\n", + "_metadata_df['split'] = split_array\n", + "_split_array = split_array" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# get_input (idx)" + ] + }, + { + "cell_type": "code", + "execution_count": 153, + "metadata": {}, + "outputs": [], + "source": [ + "idx = 3\n", + "this_metadata = _metadata_df.iloc[idx, :]\n", + "\n", + "itime = time.time()\n", + "flank_size = 400\n", + "interval_start = this_metadata['start'] - flank_size\n", + "interval_end = this_metadata['stop'] + flank_size\n", + "dnase_this = _dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end]\n", + "seq_this = _seq_bp[this_metadata['chr']][interval_start:interval_end]\n", + "data = np.column_stack([seq_this, dnase_this])\n", + "# print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 154, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "4600" + ] + }, + "execution_count": 154, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data.shape\n", + "interval_end\n", + "# itime = time.time()\n", + "# np.save(os.path.join(data_dir, 'stmp.npy'), sa)\n", + "# print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mitime\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m metadata_array = torch.stack(\n\u001b[0;32m----> 3\u001b[0;31m (torch.LongTensor(metadata_df['chr'].values), \n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLongTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetadata_df\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'celltype'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m self._y_array),\n", + "\u001b[0;31mTypeError\u001b[0m: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool." + ] + } + ], + "source": [ + "itime = time.time()\n", + "metadata_array = torch.stack(\n", + " (torch.LongTensor(chr_ints), \n", + " torch.LongTensor(celltype_ints), \n", + " _y_array),\n", + " dim=1)\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 156, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name '_metadata_array' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0m_metadata_array\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mNameError\u001b[0m: name '_metadata_array' is not defined" + ] + } + ], + "source": [ + "_metadata_array" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from examples.models.model_attributes import model_attributes" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'utils'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_attributes\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmodel_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mset_seed\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mLogger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mCSVBatchLogger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mParseKwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mload\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 22\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdataset_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0moptimizer_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/dr_benchmark/examples/train.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0msave\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'utils'" + ] + } + ], + "source": [ + "import os, csv\n", + "import time\n", + "import argparse\n", + "import IPython\n", + "import pandas as pd\n", + "import torch\n", + "import torch.nn as nn\n", + "import torchvision\n", + "import sys\n", + "from collections import defaultdict\n", + "\n", + "# TODO: Replace this once we make wilds into an installed package\n", + "sys.path.insert(1, os.path.join(sys.path[0], '..'))\n", + "\n", + "from wilds.common.data_loaders import get_train_loader, get_eval_loader\n", + "from wilds.common.grouper import CombinatorialGrouper\n", + "from wilds.common.utils import get_counts\n", + "\n", + "from examples.models.model_attributes import model_attributes\n", + "from examples.utils import set_seed, Logger, CSVBatchLogger, log_args, ParseKwargs, load\n", + "from examples.train import train\n", + "from examples.data import dataset_attributes\n", + "from examples.optimizer import optimizer_attributes\n", + "from examples.scheduler import scheduler_attributes\n", + "from examples.loss import losses\n", + "from examples.utils import log_group_data\n", + "from examples.algorithms.constructors import algorithm_constructors\n", + "\n", + "\n", + "def initialize_algorithm(args, datasets, train_grouper):\n", + " train_dataset = datasets['train']['dataset']\n", + " train_loader = datasets['train']['loader']\n", + "\n", + " # Configure the final layer of the networks used\n", + " # The code below are defaults. Edit this if you need special config for your model.\n", + " if (train_dataset.is_classification) and (train_dataset.y_size == 1):\n", + " # For single-task classification, we have one output per class\n", + " d_out = train_dataset.n_classes\n", + " elif (not train_dataset.is_classification):\n", + " # For regression, we have one output per target dimension\n", + " d_out = train_dataset.y_size\n", + " else:\n", + " # TODO: Handle dataset-specific multi-task stuff here, e.g., for OGB\n", + " pass\n", + "\n", + " # Sanity checking input args\n", + " if args.algorithm == 'groupDRO':\n", + " assert args.train_loader_kwargs['uniform_over_groups']\n", + " elif args.algorithm in ['deepCORAL', 'IRM']:\n", + " assert args.train_loader == 'group'\n", + " assert args.train_loader_kwargs['uniform_over_groups']\n", + " assert args.train_loader_kwargs['distinct_groups']\n", + "\n", + " # Other config\n", + " n_train_steps = len(train_loader) * args.n_epochs\n", + " prediction_fn = dataset_attributes[args.dataset]['prediction_fn']\n", + " loss = losses[args.loss_function]\n", + " metric_constructor = dataset_attributes[args.dataset]['metric']\n", + " train_g = train_grouper.metadata_to_group(train_dataset.metadata_array)\n", + " is_group_in_train = get_counts(train_g, train_grouper.n_groups) > 0\n", + " algorithm_constructor = algorithm_constructors[args.algorithm]\n", + " algorithm = algorithm_constructor(\n", + " args=args,\n", + " d_out=d_out,\n", + " grouper=train_grouper,\n", + " prediction_fn=prediction_fn,\n", + " loss=loss,\n", + " metric_constructor=metric_constructor,\n", + " n_train_steps=n_train_steps,\n", + " is_group_in_train=is_group_in_train)\n", + " return algorithm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "parser = argparse.ArgumentParser()\n", + "\n", + "# Dataset\n", + "parser.add_argument('-d', '--dataset', choices=dataset_attributes.keys(), required=True)\n", + "parser.add_argument('--split_scheme', default='standard',\n", + " help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')\n", + "parser.add_argument('--root_dir', default=None, required=True,\n", + " help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')\n", + "parser.add_argument('--download', default=False, action='store_true',\n", + " help='If true, tries to downloads the dataset if it does not exist in root_dir.')\n", + "parser.add_argument('--frac', type=float, default=1.0,\n", + " help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.')\n", + "\n", + "# Loaders\n", + "parser.add_argument('--train_loader', choices=['standard', 'group'], default='standard')\n", + "parser.add_argument('--train_loader_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--eval_loader', choices=['standard', 'group'], default='standard')\n", + "parser.add_argument('--batch_size', type=int, default=32)\n", + "\n", + "# Model\n", + "parser.add_argument(\n", + " '--model',\n", + " choices=model_attributes.keys(),\n", + " default='resnet50')\n", + "parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},\n", + " help='keyword arguments for model initialization passed as key1=value1 key2=value2')\n", + "parser.add_argument('--train_from_scratch', action='store_true', default=False)\n", + "\n", + "# Algorithm and objective\n", + "parser.add_argument('--algorithm', required=True, choices=algorithm_constructors.keys())\n", + "parser.add_argument('--algorithm_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--groupby_fields', nargs='+', default=None)\n", + "parser.add_argument('--loss_function', required=True, choices = losses.keys()) #TODO: make default\n", + "parser.add_argument('--val_metric', default=None)\n", + "\n", + "# Optimization\n", + "parser.add_argument('--n_epochs', type=int, default=4)\n", + "parser.add_argument('--optimizer', default=None, choices=optimizer_attributes.keys())\n", + "parser.add_argument('--lr', type=float, required=True)\n", + "parser.add_argument('--weight_decay', type=float, required=True)\n", + "parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--scheduler', default=None, choices=scheduler_attributes.keys())\n", + "parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')\n", + "parser.add_argument('--scheduler_metric_name')\n", + "\n", + "# Evaluation\n", + "parser.add_argument('--evaluate_all_splits', action='store_true', default=False)\n", + "parser.add_argument('--additional_eval_splits', nargs='+', default=[])\n", + "\n", + "# Misc\n", + "parser.add_argument('--device', default='cuda')\n", + "parser.add_argument('--seed', type=int, default=0)\n", + "parser.add_argument('--log_dir', default='./logs')\n", + "parser.add_argument('--log_every', default=50, type=int)\n", + "parser.add_argument('--save_step', type=int, default=None)\n", + "parser.add_argument('--save_best', action='store_true', default=False)\n", + "parser.add_argument('--save_last', action='store_true', default=False)\n", + "parser.add_argument('--save_outputs', action='store_true', default=False)\n", + "parser.add_argument('--no_group_logging', action='store_true', default=False)\n", + "\n", + "parser.add_argument('--resume', default=False, action='store_true')\n", + "\n", + "args = parser.parse_args()\n", + "\n", + "# Set defaults\n", + "if args.groupby_fields is None:\n", + " args.no_group_logging = True\n", + "if args.val_metric is None:\n", + " args.val_metric = dataset_attributes[args.dataset]['val_metric']\n", + "\n", + "## Initialize logs\n", + "if os.path.exists(args.log_dir) and args.resume:\n", + " resume=True\n", + " mode='a'\n", + "else:\n", + " resume=False\n", + " mode='w'\n", + "if not os.path.exists(args.log_dir):\n", + " os.makedirs(args.log_dir)\n", + "logger = Logger(os.path.join(args.log_dir, 'log.txt'), mode)\n", + "\n", + "# Record args\n", + "log_args(args, logger)\n", + "\n", + "# Set random seed\n", + "set_seed(args.seed)\n", + "\n", + "# Data\n", + "full_dataset = dataset_attributes[args.dataset]['constructor'](\n", + " root_dir=args.root_dir,\n", + " download=args.download,\n", + " split_scheme=args.split_scheme)\n", + "\n", + "# To implement data augmentation (i.e., have different transforms\n", + "# at training time vs. test time), modify these two lines:\n", + "train_transform = dataset_attributes[args.dataset]['transform'](args.model)\n", + "eval_transform = dataset_attributes[args.dataset]['transform'](args.model)\n", + "\n", + "train_grouper = CombinatorialGrouper(\n", + " dataset=full_dataset,\n", + " groupby_fields=args.groupby_fields)\n", + "\n", + "datasets = defaultdict(dict)\n", + "for split in full_dataset.split_dict.keys():\n", + " if split=='train':\n", + " transform = train_transform\n", + " verbose = True\n", + " elif split == 'val':\n", + " transform = eval_transform\n", + " verbose = True\n", + " else:\n", + " transform = eval_transform\n", + " verbose = False\n", + " # Get subset\n", + " datasets[split]['dataset'] = full_dataset.get_subset(\n", + " split,\n", + " frac=args.frac,\n", + " transform=transform)\n", + "\n", + " # Get loader\n", + " shared_loader_kwargs = {\n", + " 'num_workers': 4,\n", + " 'pin_memory': True,\n", + " 'batch_size': args.batch_size,\n", + " 'collate_fn': dataset_attributes[args.dataset]['collate']\n", + " }\n", + "\n", + " if split == 'train':\n", + " datasets[split]['loader'] = get_train_loader(\n", + " loader=args.train_loader,\n", + " dataset=datasets[split]['dataset'],\n", + " grouper=train_grouper,\n", + " train_loader_kwargs=args.train_loader_kwargs,\n", + " **shared_loader_kwargs)\n", + " else:\n", + " datasets[split]['loader'] = get_eval_loader(\n", + " loader=args.eval_loader,\n", + " dataset=datasets[split]['dataset'],\n", + " grouper=train_grouper,\n", + " **shared_loader_kwargs)\n", + "\n", + " # Set fields\n", + " datasets[split]['split'] = split\n", + " datasets[split]['name'] = full_dataset.split_names[split]\n", + " datasets[split]['verbose'] = verbose\n", + " # Loggers\n", + " datasets[split]['eval_logger'] = CSVBatchLogger(\n", + " os.path.join(args.log_dir, f'{split}_eval.csv'), mode=mode)\n", + " datasets[split]['algo_logger'] = CSVBatchLogger(\n", + " os.path.join(args.log_dir, f'{split}_algo.csv'), mode=mode)\n", + "\n", + "# Logging dataset info\n", + "if args.no_group_logging and full_dataset.is_classification and full_dataset.y_size==1:\n", + " log_grouper = CombinatorialGrouper(\n", + " dataset=full_dataset,\n", + " groupby_fields=['y'])\n", + "elif args.no_group_logging:\n", + " log_grouper = None\n", + "else:\n", + " log_grouper = train_grouper\n", + "log_group_data(args, datasets, log_grouper, logger)\n", + "\n", + "## Initialize algorithm\n", + "algorithm = initialize_algorithm(args, datasets, train_grouper)\n", + "\n", + "## Load saved results if resuming\n", + "if resume:\n", + " save_path = os.path.join(args.log_dir, 'last_model.pth')\n", + " prev_epoch, best_val_metric = load(algorithm, save_path)\n", + " epoch_offset = prev_epoch + 1\n", + "else:\n", + " epoch_offset=0\n", + " best_val_metric=None\n", + "\n", + "train(algorithm,\n", + " datasets,\n", + " logger,\n", + " args,\n", + " epoch_offset=epoch_offset,\n", + " best_val_metric=best_val_metric)\n", + "\n", + "logger.close()\n", + "for split in datasets:\n", + " datasets[split]['eval_logger'].close()\n", + " datasets[split]['algo_logger'].close()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 2", + "language": "python", + "name": "python2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 6996cc15..062e468a 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -4,9 +4,7 @@ import numpy as np from wilds.datasets.wilds_dataset import WILDSDataset from wilds.common.grouper import CombinatorialGrouper -from wilds.common.metrics.eval_metric import Accuracy - -import IPython +from wilds.common.metrics.all_metrics import Accuracy class EncodeTFBSDataset(WILDSDataset): """ From 0a37f69881bba7b9cb494367ea106a51acf52c20 Mon Sep 17 00:00:00 2001 From: aikanor Date: Tue, 9 Feb 2021 09:34:18 -0800 Subject: [PATCH 057/244] integration 7/ --- examples/configs/datasets.py | 24 +- sandbox_data.ipynb | 24 +- sbox_run_expt.ipynb | 447 --------------------------- wilds/datasets/encodetfbs_dataset.py | 2 +- 4 files changed, 37 insertions(+), 460 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index cd2d1d6f..ba644aab 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -57,8 +57,6 @@ 'weight_decay': 0.01, 'n_epochs': 5, 'n_groups_per_batch': 2, - 'irm_lambda': 1.0, - 'coral_penalty_weight': 0.1, 'algo_log_metric': 'accuracy', 'process_outputs_function': 'multiclass_logits_to_pred', }, @@ -105,6 +103,28 @@ }, 'process_outputs_function': 'multiclass_logits_to_pred', }, + 'encode-tfbs': { + 'split_scheme': 'official', + 'model': 'beagle', + 'model_kwargs': {'pretrained': False}, + 'train_transform': None, + 'eval_transform': None, + 'loss_function': 'cross_entropy', + 'groupby_fields': ['hospital'], + 'val_metric': 'ap', + 'val_metric_decreasing': False, + 'optimizer': 'Adam', + # 'optimizer_kwargs': { }, + 'scheduler': None, + 'batch_size': 128, + 'lr': 0.001, + 'weight_decay': 0.01, + 'n_epochs': 1, + 'n_groups_per_batch': 2, + # 'irm_lambda': 1.0, + # 'coral_penalty_weight': 0.1, + # 'algo_log_metric': 'accuracy', + }, 'fmow': { 'split_scheme': 'official', 'dataset_kwargs': { diff --git a/sandbox_data.ipynb b/sandbox_data.ipynb index 0a9806a6..4203968d 100644 --- a/sandbox_data.ipynb +++ b/sandbox_data.ipynb @@ -25,11 +25,15 @@ ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], - "source": [] + "source": [ + "# TODOs\n", + "\n", + "- change sequence length of model\n", + " - examples/configs/model.py\n", + " - examples/models/CNN_genome.py" + ] }, { "cell_type": "markdown", @@ -590,23 +594,23 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 2", + "display_name": "Python 3", "language": "python", - "name": "python2" + "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", - "version": 2 + "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.13" + "pygments_lexer": "ipython3", + "version": "3.8.5" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/sbox_run_expt.ipynb b/sbox_run_expt.ipynb index 612397ce..6d1a135a 100644 --- a/sbox_run_expt.ipynb +++ b/sbox_run_expt.ipynb @@ -153,167 +153,6 @@ " print(ct, time.time() - itime)" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class Beagle2(nn.Module):\n", - " \"\"\"\n", - " Neural net models over genomic sequence.\n", - " Input:\n", - " - sequence_length: int (default 1000) \n", - " - Shape: (N, 5, sequence_length, 1) with batch size N.\n", - " \n", - " Output:\n", - " - prediction (Tensor): float torch tensor of shape (N, )\n", - " \n", - " TODO: Finish docstring.\n", - " \"\"\"\n", - " def __init__(self):\n", - " \"\"\"\n", - " Parameters\n", - " ----------\n", - " sequence_length : int\n", - " n_genomic_features : int\n", - " \"\"\"\n", - " super(Beagle2, self).__init__()\n", - "\n", - " self.dropout = 0.3\n", - " self.num_cell_types = 1\n", - " self.conv1 = nn.Conv2d(5, 300, (19, 1), stride = (1, 1), padding=(9,0))\n", - " self.conv2 = nn.Conv2d(300, 200, (11, 1), stride = (1, 1), padding = (5,0))\n", - " self.conv3 = nn.Conv2d(200, 200, (7, 1), stride = (1, 1), padding = (4,0))\n", - " self.bn1 = nn.BatchNorm2d(300)\n", - " self.bn2 = nn.BatchNorm2d(200)\n", - " self.bn3 = nn.BatchNorm2d(200)\n", - " self.maxpool1 = nn.MaxPool2d((3, 1))\n", - " self.maxpool2 = nn.MaxPool2d((4, 1))\n", - " self.maxpool3 = nn.MaxPool2d((4, 1))\n", - "\n", - " self.fc1 = nn.Linear(4200, 1000)\n", - " self.bn4 = nn.BatchNorm1d(1000)\n", - "\n", - " self.fc2 = nn.Linear(1000, 1000)\n", - " self.bn5 = nn.BatchNorm1d(1000)\n", - "\n", - " self.fc3 = nn.Linear(1000, self.num_cell_types)\n", - "\n", - " def forward(self, s):\n", - " s = s.permute(0, 2, 1).contiguous() # batch_size x 4 x 1000\n", - " s = s.view(-1, 5, 1000, 1) # batch_size x 4 x 1000 x 1 [4 channels]\n", - " s = self.maxpool1(F.relu(self.bn1(self.conv1(s)))) # batch_size x 300 x 333 x 1\n", - " s = self.maxpool2(F.relu(self.bn2(self.conv2(s)))) # batch_size x 200 x 83 x 1\n", - " s = self.maxpool3(F.relu(self.bn3(self.conv3(s)))) # batch_size x 200 x 21 x 1\n", - " s = s.view(-1, 4200)\n", - " conv_out = s\n", - "\n", - " s = F.dropout(F.relu(self.bn4(self.fc1(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", - " #s = F.dropout(F.relu(self.bn5(self.fc2(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", - " \n", - " \n", - " s = self.fc3(s)\n", - "\n", - " return s, conv_out\n", - "\n", - "\n", - "class DanQ(nn.Module):\n", - " def __init__(self, sequence_length, n_genomic_features):\n", - " \"\"\"\n", - " Parameters\n", - " ----------\n", - " sequence_length : int\n", - " Input sequence length\n", - " n_genomic_features : int\n", - " Total number of features to predict\n", - " \"\"\"\n", - " super(DanQ, self).__init__()\n", - " self.nnet = nn.Sequential(\n", - " nn.Conv1d(4, 320, kernel_size=26),\n", - " nn.ReLU(inplace=True),\n", - " nn.MaxPool1d(\n", - " kernel_size=13, stride=13),\n", - " nn.Dropout(0.2))\n", - "\n", - " self.bdlstm = nn.Sequential(\n", - " nn.LSTM(\n", - " 320, 320, num_layers=1, batch_first=True, bidirectional=True))\n", - "\n", - " self._n_channels = math.floor(\n", - " (sequence_length - 25) / 13)\n", - " self.classifier = nn.Sequential(\n", - " nn.Dropout(0.5),\n", - " nn.Linear(self._n_channels * 640, 925),\n", - " nn.ReLU(inplace=True),\n", - " nn.Linear(925, n_genomic_features),\n", - " nn.Sigmoid())\n", - "\n", - " def forward(self, x):\n", - " \"\"\"Forward propagation of a batch.\n", - " \"\"\"\n", - " out = self.nnet(x)\n", - " reshape_out = out.transpose(0, 1).transpose(0, 2)\n", - " out, _ = self.bdlstm(reshape_out)\n", - " out = out.transpose(0, 1)\n", - " reshape_out = out.contiguous().view(\n", - " out.size(0), 640 * self._n_channels)\n", - " predict = self.classifier(reshape_out)\n", - " return predict\n", - "\n", - "\n", - "class DeepSEA(nn.Module):\n", - " def __init__(self, sequence_length, n_genomic_features):\n", - " \"\"\"\n", - " Parameters\n", - " ----------\n", - " sequence_length : int\n", - " n_genomic_features : int\n", - " \"\"\"\n", - " super(DeepSEA, self).__init__()\n", - " conv_kernel_size = 8\n", - " pool_kernel_size = 4\n", - "\n", - " self.conv_net = nn.Sequential(\n", - " nn.Conv1d(4, 320, kernel_size=conv_kernel_size),\n", - " nn.ReLU(inplace=True),\n", - " nn.MaxPool1d(\n", - " kernel_size=pool_kernel_size, stride=pool_kernel_size),\n", - " nn.Dropout(p=0.2),\n", - "\n", - " nn.Conv1d(320, 480, kernel_size=conv_kernel_size),\n", - " nn.ReLU(inplace=True),\n", - " nn.MaxPool1d(\n", - " kernel_size=pool_kernel_size, stride=pool_kernel_size),\n", - " nn.Dropout(p=0.2),\n", - "\n", - " nn.Conv1d(480, 960, kernel_size=conv_kernel_size),\n", - " nn.ReLU(inplace=True),\n", - " nn.Dropout(p=0.5))\n", - "\n", - " reduce_by = conv_kernel_size - 1\n", - " pool_kernel_size = float(pool_kernel_size)\n", - " self.n_channels = int(\n", - " np.floor(\n", - " (np.floor(\n", - " (sequence_length - reduce_by) / pool_kernel_size)\n", - " - reduce_by) / pool_kernel_size)\n", - " - reduce_by)\n", - " self.classifier = nn.Sequential(\n", - " nn.Linear(960 * self.n_channels, n_genomic_features),\n", - " nn.ReLU(inplace=True),\n", - " nn.Linear(n_genomic_features, n_genomic_features),\n", - " nn.Sigmoid())\n", - "\n", - " def forward(self, x):\n", - " \"\"\"Forward propagation of a batch.\n", - " \"\"\"\n", - " out = self.conv_net(x)\n", - " reshape_out = out.view(out.size(0), 960 * self.n_channels)\n", - " predict = self.classifier(reshape_out)\n", - " return predict" - ] - }, { "cell_type": "code", "execution_count": 78, @@ -720,292 +559,6 @@ "source": [ "from examples.models.model_attributes import model_attributes" ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "ename": "ModuleNotFoundError", - "evalue": "No module named 'utils'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_attributes\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmodel_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mset_seed\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mLogger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mCSVBatchLogger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mParseKwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mload\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 22\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdataset_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0moptimizer_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/dr_benchmark/examples/train.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0msave\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'utils'" - ] - } - ], - "source": [ - "import os, csv\n", - "import time\n", - "import argparse\n", - "import IPython\n", - "import pandas as pd\n", - "import torch\n", - "import torch.nn as nn\n", - "import torchvision\n", - "import sys\n", - "from collections import defaultdict\n", - "\n", - "# TODO: Replace this once we make wilds into an installed package\n", - "sys.path.insert(1, os.path.join(sys.path[0], '..'))\n", - "\n", - "from wilds.common.data_loaders import get_train_loader, get_eval_loader\n", - "from wilds.common.grouper import CombinatorialGrouper\n", - "from wilds.common.utils import get_counts\n", - "\n", - "from examples.models.model_attributes import model_attributes\n", - "from examples.utils import set_seed, Logger, CSVBatchLogger, log_args, ParseKwargs, load\n", - "from examples.train import train\n", - "from examples.data import dataset_attributes\n", - "from examples.optimizer import optimizer_attributes\n", - "from examples.scheduler import scheduler_attributes\n", - "from examples.loss import losses\n", - "from examples.utils import log_group_data\n", - "from examples.algorithms.constructors import algorithm_constructors\n", - "\n", - "\n", - "def initialize_algorithm(args, datasets, train_grouper):\n", - " train_dataset = datasets['train']['dataset']\n", - " train_loader = datasets['train']['loader']\n", - "\n", - " # Configure the final layer of the networks used\n", - " # The code below are defaults. Edit this if you need special config for your model.\n", - " if (train_dataset.is_classification) and (train_dataset.y_size == 1):\n", - " # For single-task classification, we have one output per class\n", - " d_out = train_dataset.n_classes\n", - " elif (not train_dataset.is_classification):\n", - " # For regression, we have one output per target dimension\n", - " d_out = train_dataset.y_size\n", - " else:\n", - " # TODO: Handle dataset-specific multi-task stuff here, e.g., for OGB\n", - " pass\n", - "\n", - " # Sanity checking input args\n", - " if args.algorithm == 'groupDRO':\n", - " assert args.train_loader_kwargs['uniform_over_groups']\n", - " elif args.algorithm in ['deepCORAL', 'IRM']:\n", - " assert args.train_loader == 'group'\n", - " assert args.train_loader_kwargs['uniform_over_groups']\n", - " assert args.train_loader_kwargs['distinct_groups']\n", - "\n", - " # Other config\n", - " n_train_steps = len(train_loader) * args.n_epochs\n", - " prediction_fn = dataset_attributes[args.dataset]['prediction_fn']\n", - " loss = losses[args.loss_function]\n", - " metric_constructor = dataset_attributes[args.dataset]['metric']\n", - " train_g = train_grouper.metadata_to_group(train_dataset.metadata_array)\n", - " is_group_in_train = get_counts(train_g, train_grouper.n_groups) > 0\n", - " algorithm_constructor = algorithm_constructors[args.algorithm]\n", - " algorithm = algorithm_constructor(\n", - " args=args,\n", - " d_out=d_out,\n", - " grouper=train_grouper,\n", - " prediction_fn=prediction_fn,\n", - " loss=loss,\n", - " metric_constructor=metric_constructor,\n", - " n_train_steps=n_train_steps,\n", - " is_group_in_train=is_group_in_train)\n", - " return algorithm" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "parser = argparse.ArgumentParser()\n", - "\n", - "# Dataset\n", - "parser.add_argument('-d', '--dataset', choices=dataset_attributes.keys(), required=True)\n", - "parser.add_argument('--split_scheme', default='standard',\n", - " help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')\n", - "parser.add_argument('--root_dir', default=None, required=True,\n", - " help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')\n", - "parser.add_argument('--download', default=False, action='store_true',\n", - " help='If true, tries to downloads the dataset if it does not exist in root_dir.')\n", - "parser.add_argument('--frac', type=float, default=1.0,\n", - " help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.')\n", - "\n", - "# Loaders\n", - "parser.add_argument('--train_loader', choices=['standard', 'group'], default='standard')\n", - "parser.add_argument('--train_loader_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--eval_loader', choices=['standard', 'group'], default='standard')\n", - "parser.add_argument('--batch_size', type=int, default=32)\n", - "\n", - "# Model\n", - "parser.add_argument(\n", - " '--model',\n", - " choices=model_attributes.keys(),\n", - " default='resnet50')\n", - "parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},\n", - " help='keyword arguments for model initialization passed as key1=value1 key2=value2')\n", - "parser.add_argument('--train_from_scratch', action='store_true', default=False)\n", - "\n", - "# Algorithm and objective\n", - "parser.add_argument('--algorithm', required=True, choices=algorithm_constructors.keys())\n", - "parser.add_argument('--algorithm_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--groupby_fields', nargs='+', default=None)\n", - "parser.add_argument('--loss_function', required=True, choices = losses.keys()) #TODO: make default\n", - "parser.add_argument('--val_metric', default=None)\n", - "\n", - "# Optimization\n", - "parser.add_argument('--n_epochs', type=int, default=4)\n", - "parser.add_argument('--optimizer', default=None, choices=optimizer_attributes.keys())\n", - "parser.add_argument('--lr', type=float, required=True)\n", - "parser.add_argument('--weight_decay', type=float, required=True)\n", - "parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--scheduler', default=None, choices=scheduler_attributes.keys())\n", - "parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')\n", - "parser.add_argument('--scheduler_metric_name')\n", - "\n", - "# Evaluation\n", - "parser.add_argument('--evaluate_all_splits', action='store_true', default=False)\n", - "parser.add_argument('--additional_eval_splits', nargs='+', default=[])\n", - "\n", - "# Misc\n", - "parser.add_argument('--device', default='cuda')\n", - "parser.add_argument('--seed', type=int, default=0)\n", - "parser.add_argument('--log_dir', default='./logs')\n", - "parser.add_argument('--log_every', default=50, type=int)\n", - "parser.add_argument('--save_step', type=int, default=None)\n", - "parser.add_argument('--save_best', action='store_true', default=False)\n", - "parser.add_argument('--save_last', action='store_true', default=False)\n", - "parser.add_argument('--save_outputs', action='store_true', default=False)\n", - "parser.add_argument('--no_group_logging', action='store_true', default=False)\n", - "\n", - "parser.add_argument('--resume', default=False, action='store_true')\n", - "\n", - "args = parser.parse_args()\n", - "\n", - "# Set defaults\n", - "if args.groupby_fields is None:\n", - " args.no_group_logging = True\n", - "if args.val_metric is None:\n", - " args.val_metric = dataset_attributes[args.dataset]['val_metric']\n", - "\n", - "## Initialize logs\n", - "if os.path.exists(args.log_dir) and args.resume:\n", - " resume=True\n", - " mode='a'\n", - "else:\n", - " resume=False\n", - " mode='w'\n", - "if not os.path.exists(args.log_dir):\n", - " os.makedirs(args.log_dir)\n", - "logger = Logger(os.path.join(args.log_dir, 'log.txt'), mode)\n", - "\n", - "# Record args\n", - "log_args(args, logger)\n", - "\n", - "# Set random seed\n", - "set_seed(args.seed)\n", - "\n", - "# Data\n", - "full_dataset = dataset_attributes[args.dataset]['constructor'](\n", - " root_dir=args.root_dir,\n", - " download=args.download,\n", - " split_scheme=args.split_scheme)\n", - "\n", - "# To implement data augmentation (i.e., have different transforms\n", - "# at training time vs. test time), modify these two lines:\n", - "train_transform = dataset_attributes[args.dataset]['transform'](args.model)\n", - "eval_transform = dataset_attributes[args.dataset]['transform'](args.model)\n", - "\n", - "train_grouper = CombinatorialGrouper(\n", - " dataset=full_dataset,\n", - " groupby_fields=args.groupby_fields)\n", - "\n", - "datasets = defaultdict(dict)\n", - "for split in full_dataset.split_dict.keys():\n", - " if split=='train':\n", - " transform = train_transform\n", - " verbose = True\n", - " elif split == 'val':\n", - " transform = eval_transform\n", - " verbose = True\n", - " else:\n", - " transform = eval_transform\n", - " verbose = False\n", - " # Get subset\n", - " datasets[split]['dataset'] = full_dataset.get_subset(\n", - " split,\n", - " frac=args.frac,\n", - " transform=transform)\n", - "\n", - " # Get loader\n", - " shared_loader_kwargs = {\n", - " 'num_workers': 4,\n", - " 'pin_memory': True,\n", - " 'batch_size': args.batch_size,\n", - " 'collate_fn': dataset_attributes[args.dataset]['collate']\n", - " }\n", - "\n", - " if split == 'train':\n", - " datasets[split]['loader'] = get_train_loader(\n", - " loader=args.train_loader,\n", - " dataset=datasets[split]['dataset'],\n", - " grouper=train_grouper,\n", - " train_loader_kwargs=args.train_loader_kwargs,\n", - " **shared_loader_kwargs)\n", - " else:\n", - " datasets[split]['loader'] = get_eval_loader(\n", - " loader=args.eval_loader,\n", - " dataset=datasets[split]['dataset'],\n", - " grouper=train_grouper,\n", - " **shared_loader_kwargs)\n", - "\n", - " # Set fields\n", - " datasets[split]['split'] = split\n", - " datasets[split]['name'] = full_dataset.split_names[split]\n", - " datasets[split]['verbose'] = verbose\n", - " # Loggers\n", - " datasets[split]['eval_logger'] = CSVBatchLogger(\n", - " os.path.join(args.log_dir, f'{split}_eval.csv'), mode=mode)\n", - " datasets[split]['algo_logger'] = CSVBatchLogger(\n", - " os.path.join(args.log_dir, f'{split}_algo.csv'), mode=mode)\n", - "\n", - "# Logging dataset info\n", - "if args.no_group_logging and full_dataset.is_classification and full_dataset.y_size==1:\n", - " log_grouper = CombinatorialGrouper(\n", - " dataset=full_dataset,\n", - " groupby_fields=['y'])\n", - "elif args.no_group_logging:\n", - " log_grouper = None\n", - "else:\n", - " log_grouper = train_grouper\n", - "log_group_data(args, datasets, log_grouper, logger)\n", - "\n", - "## Initialize algorithm\n", - "algorithm = initialize_algorithm(args, datasets, train_grouper)\n", - "\n", - "## Load saved results if resuming\n", - "if resume:\n", - " save_path = os.path.join(args.log_dir, 'last_model.pth')\n", - " prev_epoch, best_val_metric = load(algorithm, save_path)\n", - " epoch_offset = prev_epoch + 1\n", - "else:\n", - " epoch_offset=0\n", - " best_val_metric=None\n", - "\n", - "train(algorithm,\n", - " datasets,\n", - " logger,\n", - " args,\n", - " epoch_offset=epoch_offset,\n", - " best_val_metric=best_val_metric)\n", - "\n", - "logger.close()\n", - "for split in datasets:\n", - " datasets[split]['eval_logger'].close()\n", - " datasets[split]['algo_logger'].close()" - ] } ], "metadata": { diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 062e468a..e55fba11 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -127,7 +127,7 @@ def get_input(self, idx): Computes this from: (1) sequence features in self._seq_bp (2) DNase features in self._dnase_allcelltypes - (3) Metadata for the index (location along the genome with 1kb window width) + (3) Metadata for the index (location along the genome with 200bp window width) """ this_metadata = self._metadata_df.iloc[idx, :] flank_size = 400 From 01b5f5cb2efeafa8ef7320d05b297757db81185c Mon Sep 17 00:00:00 2001 From: aikanor Date: Tue, 9 Feb 2021 10:53:53 -0800 Subject: [PATCH 058/244] integration 8/ --- wilds/datasets/encodetfbs_dataset.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index e55fba11..d26d052a 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -8,8 +8,20 @@ class EncodeTFBSDataset(WILDSDataset): """ - EncodeTFBS dataset - Website: https://www.synapse.org/#!Synapse:syn6131484 + ENCODE-DREAM-wilds dataset of transcription factor binding sites. + This is a subset of the dataset from the ENCODE-DREAM in vivo Transcription Factor Binding Site Prediction Challenge. + + Input (x): + 1000-base-pair regions of sequence with a quantified chromatin accessibility readout. + + Label (y): + y is binary. It is 1 if the central 200bp region is bound by the transcription factor MAX, and 0 otherwise. + + Metadata: + Each sequence is annotated with the celltype of origin (a string) and the chromosome of origin (a string). + + Website: + https://www.synapse.org/#!Synapse:syn6131484 """ def __init__(self, root_dir, download, split_scheme): @@ -19,6 +31,7 @@ def __init__(self, root_dir, download, split_scheme): self._y_size = 1 self._n_classes = 2 + # self._tr_chrs = ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] self._tr_chrs = ['chr2', 'chr9', 'chr11'] self._te_chrs = ['chr1', 'chr8', 'chr21'] self._transcription_factor = 'MAX' From 2b8dc7dbfcd1b4f0afaffc980a2547fe176fa964 Mon Sep 17 00:00:00 2001 From: aikanor Date: Tue, 9 Feb 2021 11:37:26 -0800 Subject: [PATCH 059/244] integration 9/ --- sandbox_data.ipynb | 4 +++- wilds/common/metrics/all_metrics.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/sandbox_data.ipynb b/sandbox_data.ipynb index 4203968d..681c5ec2 100644 --- a/sandbox_data.ipynb +++ b/sandbox_data.ipynb @@ -8,7 +8,7 @@ " - run_expt.py\n", " - configs\n", " - [x] supported.py\n", - " - [ ] model.py\n", + " - [x] model.py\n", " - [ ] datasets.py\n", " - models\n", " - [x] CNN_genome.py\n", @@ -30,6 +30,8 @@ "source": [ "# TODOs\n", "\n", + "- change evaluation metric\n", + "\n", "- change sequence length of model\n", " - examples/configs/model.py\n", " - examples/models/CNN_genome.py" diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py index 0f5d7eb1..6a680467 100644 --- a/wilds/common/metrics/all_metrics.py +++ b/wilds/common/metrics/all_metrics.py @@ -77,6 +77,25 @@ def _compute(self, y_pred, y_true): def worst(self, metrics): return minimum(metrics) +class AveragePrecision(Metric): + def __init__(self, prediction_fn=logits_to_pred, name=None, average='weighted'): + self.prediction_fn = prediction_fn + if name is None: + name = f'avgprec' + if average is not None: + name+=f'-{average}' + self.average = average + super().__init__(name=name) + + def _compute(self, y_pred, y_true): + if self.prediction_fn is not None: + y_pred = self.prediction_fn(y_pred) + score = sklearn.metrics.average_precision_score(y_true, y_pred, average=self.average, labels=torch.unique(y_true)) + return torch.tensor(score) + + def worst(self, metrics): + return minimum(metrics) + class F1(Metric): def __init__(self, prediction_fn=None, name=None, average='binary'): self.prediction_fn = prediction_fn From 74fe367df4d04139a28e39eb8eb4c1b98129d1d7 Mon Sep 17 00:00:00 2001 From: aikanor Date: Tue, 9 Feb 2021 11:50:07 -0800 Subject: [PATCH 060/244] using avg accuracy across balanced splits for now, until avg precision is tested --- examples/configs/datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index ba644aab..b3cb0a56 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -111,7 +111,7 @@ 'eval_transform': None, 'loss_function': 'cross_entropy', 'groupby_fields': ['hospital'], - 'val_metric': 'ap', + 'val_metric': 'acc_avg', 'val_metric_decreasing': False, 'optimizer': 'Adam', # 'optimizer_kwargs': { }, @@ -123,7 +123,7 @@ 'n_groups_per_batch': 2, # 'irm_lambda': 1.0, # 'coral_penalty_weight': 0.1, - # 'algo_log_metric': 'accuracy', + 'algo_log_metric': 'accuracy' }, 'fmow': { 'split_scheme': 'official', From c8132e136100f88d573848c25dba8b86f4cce537 Mon Sep 17 00:00:00 2001 From: aikanor Date: Tue, 9 Feb 2021 11:56:06 -0800 Subject: [PATCH 061/244] fix --- examples/configs/datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index b3cb0a56..9d295619 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -110,7 +110,7 @@ 'train_transform': None, 'eval_transform': None, 'loss_function': 'cross_entropy', - 'groupby_fields': ['hospital'], + 'groupby_fields': ['celltype'], 'val_metric': 'acc_avg', 'val_metric_decreasing': False, 'optimizer': 'Adam', @@ -121,9 +121,9 @@ 'weight_decay': 0.01, 'n_epochs': 1, 'n_groups_per_batch': 2, + 'algo_log_metric': 'accuracy', # 'irm_lambda': 1.0, # 'coral_penalty_weight': 0.1, - 'algo_log_metric': 'accuracy' }, 'fmow': { 'split_scheme': 'official', From f64758761c05c6cacca6f2f6d488f4d636bfc4be Mon Sep 17 00:00:00 2001 From: aikanor Date: Tue, 9 Feb 2021 15:56:50 -0800 Subject: [PATCH 062/244] integration 10/ --- .../encode-tfbs/prep_accessibility.py | 4 +--- examples/configs/datasets.py | 24 ++----------------- wilds/datasets/encodetfbs_dataset.py | 11 +++++---- 3 files changed, 9 insertions(+), 30 deletions(-) diff --git a/dataset_preprocessing/encode-tfbs/prep_accessibility.py b/dataset_preprocessing/encode-tfbs/prep_accessibility.py index 7342f797..31bf872c 100644 --- a/dataset_preprocessing/encode-tfbs/prep_accessibility.py +++ b/dataset_preprocessing/encode-tfbs/prep_accessibility.py @@ -1,8 +1,6 @@ -import numpy, pandas +import numpy as np import pyBigWig -from tqdm import tqdm - # Human chromosome names chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 9d295619..cd2d1d6f 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -57,6 +57,8 @@ 'weight_decay': 0.01, 'n_epochs': 5, 'n_groups_per_batch': 2, + 'irm_lambda': 1.0, + 'coral_penalty_weight': 0.1, 'algo_log_metric': 'accuracy', 'process_outputs_function': 'multiclass_logits_to_pred', }, @@ -103,28 +105,6 @@ }, 'process_outputs_function': 'multiclass_logits_to_pred', }, - 'encode-tfbs': { - 'split_scheme': 'official', - 'model': 'beagle', - 'model_kwargs': {'pretrained': False}, - 'train_transform': None, - 'eval_transform': None, - 'loss_function': 'cross_entropy', - 'groupby_fields': ['celltype'], - 'val_metric': 'acc_avg', - 'val_metric_decreasing': False, - 'optimizer': 'Adam', - # 'optimizer_kwargs': { }, - 'scheduler': None, - 'batch_size': 128, - 'lr': 0.001, - 'weight_decay': 0.01, - 'n_epochs': 1, - 'n_groups_per_batch': 2, - 'algo_log_metric': 'accuracy', - # 'irm_lambda': 1.0, - # 'coral_penalty_weight': 0.1, - }, 'fmow': { 'split_scheme': 'official', 'dataset_kwargs': { diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index d26d052a..23f0b1d7 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -25,7 +25,8 @@ class EncodeTFBSDataset(WILDSDataset): """ def __init__(self, root_dir, download, split_scheme): - self._dataset_name = 'encodeTFBS' + self._dataset_name = 'encode-tfbs' + self._version = '1.0' self._download_url = 'https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/' self._data_dir = self.initialize_data_dir(root_dir, download) self._y_size = 1 @@ -55,10 +56,10 @@ def __init__(self, root_dir, download, split_scheme): self._dnase_allcelltypes = {} for ct in self._all_celltypes: dnase_filename = os.path.join(self._data_dir, '{}_dnase.npz'.format(ct)) - dnase_npz_file = np.load(dnase_filename) + dnase_npz_contents = np.load(dnase_filename) self._dnase_allcelltypes[ct] = {} - for chrom in seq_bp: - self._dnase_allcelltypes[ct][chrom] = dnase_npz_file[chrom] + for chrom in self._seq_bp: + self._dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom] # Read in metadata dataframe from training+validation data train_chr = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.train.labels.tsv.gz'.format(self._transcription_factor)), sep='\t') @@ -130,7 +131,7 @@ def __init__(self, root_dir, download, split_scheme): self._eval_grouper = CombinatorialGrouper( dataset=self, groupby_fields=['celltype']) - self._metric = Auprc() + self._metric = Accuracy() super().__init__(root_dir, download, split_scheme) From de20ef12697f865bcd1a05fe1c50c11bc618eb1a Mon Sep 17 00:00:00 2001 From: aikanor Date: Tue, 9 Feb 2021 19:00:21 -0800 Subject: [PATCH 063/244] integration 11/ --- .../encode-tfbs/prep_sequence.py | 214 ++++++++---------- examples/configs/datasets.py | 22 ++ 2 files changed, 116 insertions(+), 120 deletions(-) diff --git a/dataset_preprocessing/encode-tfbs/prep_sequence.py b/dataset_preprocessing/encode-tfbs/prep_sequence.py index 7f396d9f..7d6ede23 100644 --- a/dataset_preprocessing/encode-tfbs/prep_sequence.py +++ b/dataset_preprocessing/encode-tfbs/prep_sequence.py @@ -1,130 +1,104 @@ import argparse, time -import numpy, pandas +import numpy as np from tqdm import tqdm # Human chromosome names chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] -def one_hot_encode(sequence, ignore='N', alphabet=None, dtype='int8', - verbose=False, **kwargs): - """Converts a string or list of characters into a one-hot encoding. - This function will take in either a string or a list and convert it into a - one-hot encoding. If the input is a string, each character is assumed to be - a different symbol, e.g. 'ACGT' is assumed to be a sequence of four - characters. If the input is a list, the elements can be any size. - Although this function will be used here primarily to convert nucleotide - sequences into one-hot encoding with an alphabet of size 4, in principle - this function can be used for any types of sequences. - Parameters - ---------- - sequence : str or list - The sequence to convert to a one-hot encoding. - ignore : str, optional - A character to indicate setting nothing to 1 for that row, keeping the - encoding entirely 0's for that row. In the context of genomics, this is - the N character. Default is 'N'. - alphabet : set or tuple or list, optional - A pre-defined alphabet. If None is passed in, the alphabet will be - determined from the sequence, but this may be time consuming for - large sequences. Default is None. - dtype : str or numpy.dtype, optional - The data type of the returned encoding. Default is int8. - verbose : bool or str, optional - Whether to display a progress bar. If a string is passed in, use as the - name of the progressbar. Default is False. - kwargs : arguments - Arguments to be passed into tqdm. Default is None. - Returns - ------- - ohe : numpy.ndarray - A binary matrix of shape (alphabet_size, sequence_length) where - alphabet_size is the number of unique elements in the sequence and - sequence_length is the length of the input sequence. - """ - - name = None if verbose in (True, False) else verbose - d = verbose is False - - if isinstance(sequence, str): - sequence = list(sequence) - - alphabet = alphabet or numpy.unique(sequence) - alphabet = [char for char in alphabet if char != ignore] - alphabet_lookup = {char: i for i, char in enumerate(alphabet)} - - ohe = numpy.zeros((len(sequence), len(alphabet)), dtype=dtype) - for i, char in tqdm(enumerate(sequence), disable=d, desc=name, **kwargs): - if char != ignore: - idx = alphabet_lookup[char] - ohe[i, idx] = 1 - - return ohe - - -def read_fasta(filename, include_chroms=None, exclude_chroms=None, - ignore='N', alphabet=['A', 'C', 'G', 'T', 'N'], verbose=True): - """Read in a FASTA file and output a dictionary of sequences. - This function will take in the path to a FASTA-formatted file and output - a string containing the sequence for each chromosome. Optionally, - the user can specify a set of chromosomes to include or exclude from - the returned dictionary. - Parameters - ---------- - filename : str - The path to the FASTA-formatted file to open. - include_chroms : set or tuple or list, optional - The exact names of chromosomes in the FASTA file to include, excluding - all others. If None, include all chromosomes (except those specified by - exclude_chroms). Default is None. - exclude_chroms : set or tuple or list, optional - The exact names of chromosomes in the FASTA file to exclude, including - all others. If None, include all chromosomes (or the set specified by - include_chroms). Default is None. - ignore : str, optional - A character to indicate setting nothing to 1 for that row, keeping the - encoding entirely 0's for that row. In the context of genomics, this is - the N character. Default is 'N'. - alphabet : set or tuple or list, optional - A pre-defined alphabet. If None is passed in, the alphabet will be - determined from the sequence, but this may be time consuming for - large sequences. Must include the ignore character. Default is - ['A', 'C', 'G', 'T', 'N']. - verbose : bool or str, optional - Whether to display a progress bar. If a string is passed in, use as the - name of the progressbar. Default is False. - Returns - ------- - chroms : dict - A dictionary of strings where the keys are the names of the - chromosomes (exact strings from the header lines in the FASTA file) - and the values are the strings encoded there. - """ - - sequences = {} - name, sequence = None, None - skip_chrom = False - - with open(filename, "r") as infile: - for line in tqdm(infile, disable=not verbose): - if line.startswith(">"): - if name is not None and skip_chrom is False: - sequences[name] = ''.join(sequence) - - sequence = [] - name = line[1:].strip("\n") - if include_chroms is not None and name not in include_chroms: - skip_chrom = True - elif exclude_chroms is not None and name in exclude_chroms: - skip_chrom = True - else: - skip_chrom = False - - else: - if skip_chrom == False: - sequence.append(line.rstrip("\n").upper()) - - return sequences +def one_hot_encode(sequence, ignore='N', alphabet=None, dtype='int8', verbose=False, **kwargs): + """ + Converts a string or list of characters into a one-hot encoding. + This function will take in either a string or a list and convert it into a one-hot encoding. If the input is a string, each character is assumed to be a different symbol, e.g. 'ACGT' is assumed to be a sequence of four characters. If the input is a list, the elements can be any size. + Although this function will be used here primarily to convert nucleotide sequences into one-hot encoding with an alphabet of size 4, in principle this function can be used for any types of sequences. + + Parameters + ---------- + sequence : str or list + The sequence to convert to a one-hot encoding. + ignore : str, optional + A character to indicate setting nothing to 1 for that row, keeping the encoding entirely 0's for that row. In the context of genomics, this is the N character. Default is 'N'. + alphabet : set or tuple or list, optional + A pre-defined alphabet. If None is passed in, the alphabet will be determined from the sequence, but this may be time consuming for large sequences. Default is None. + dtype : str or numpy.dtype, optional + The data type of the returned encoding. Default is int8. + verbose : bool or str, optional + Whether to display a progress bar. If a string is passed in, use as the name of the progressbar. Default is False. + kwargs : arguments + Arguments to be passed into tqdm. Default is None. + + Returns + ------- + ohe : numpy.ndarray + A binary matrix of shape (alphabet_size, sequence_length) where alphabet_size is the number of unique elements in the sequence and sequence_length is the length of the input sequence. + """ + + name = None if verbose in (True, False) else verbose + d = verbose is False + + if isinstance(sequence, str): + sequence = list(sequence) + + alphabet = alphabet or np.unique(sequence) + alphabet = [char for char in alphabet if char != ignore] + alphabet_lookup = {char: i for i, char in enumerate(alphabet)} + + ohe = np.zeros((len(sequence), len(alphabet)), dtype=dtype) + for i, char in tqdm(enumerate(sequence), disable=d, desc=name, **kwargs): + if char != ignore: + idx = alphabet_lookup[char] + ohe[i, idx] = 1 + + return ohe + + +def read_fasta(filename, include_chroms=None, exclude_chroms=None, ignore='N', alphabet=['A', 'C', 'G', 'T', 'N'], verbose=True): + """ + Read in a FASTA file and output a dictionary of sequences. + This function will take in the path to a FASTA-formatted file and output a string containing the sequence for each chromosome. Optionally, the user can specify a set of chromosomes to include or exclude from the returned dictionary. + + Parameters + ---------- + filename : str + The path to the FASTA-formatted file to open. + include_chroms : set or tuple or list, optional + The exact names of chromosomes in the FASTA file to include, excluding all others. If None, include all chromosomes (except those specified by exclude_chroms). Default is None. + exclude_chroms : set or tuple or list, optional + The exact names of chromosomes in the FASTA file to exclude, including all others. If None, include all chromosomes (or the set specified by include_chroms). Default is None. + ignore : str, optional + A character to indicate setting nothing to 1 for that row, keeping the encoding entirely 0's for that row. In the context of genomics, this is the N character. Default is 'N'. + alphabet : set or tuple or list, optional + A pre-defined alphabet. If None is passed in, the alphabet will be determined from the sequence, but this may be time consuming for large sequences. Must include the ignore character. Default is ['A', 'C', 'G', 'T', 'N']. + verbose : bool or str, optional + Whether to display a progress bar. If a string is passed in, use as the name of the progressbar. Default is False. + + Returns + ------- + chroms : dict + A dictionary of strings where the keys are the names of the chromosomes (exact strings from the header lines in the FASTA file) and the values are the strings encoded there. + """ + + sequences = {} + name, sequence = None, None + skip_chrom = False + + with open(filename, "r") as infile: + for line in tqdm(infile, disable=not verbose): + if line.startswith(">"): + if name is not None and skip_chrom is False: + sequences[name] = ''.join(sequence) + sequence = [] + name = line[1:].strip("\n") + if include_chroms is not None and name not in include_chroms: + skip_chrom = True + elif exclude_chroms is not None and name in exclude_chroms: + skip_chrom = True + else: + skip_chrom = False + else: + if skip_chrom == False: + sequence.append(line.rstrip("\n").upper()) + return sequences def generate_sequence_archive(seq_path='sequence/hg19.genome.fa', output_dir): diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index cd2d1d6f..727d0e92 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -105,6 +105,28 @@ }, 'process_outputs_function': 'multiclass_logits_to_pred', }, + 'encode-tfbs': { + 'split_scheme': 'official', + 'model': 'beagle', + 'model_kwargs': {'pretrained': False}, + 'train_transform': None, + 'eval_transform': None, + 'loss_function': 'cross_entropy', + 'groupby_fields': ['celltype'], + 'val_metric': 'acc_avg', + 'val_metric_decreasing': False, + 'optimizer': 'Adam', + # 'optimizer_kwargs': { }, + 'scheduler': None, + 'batch_size': 128, + 'lr': 0.001, + 'weight_decay': 0.01, + 'n_epochs': 1, + 'n_groups_per_batch': 2, + 'algo_log_metric': 'accuracy', + # 'irm_lambda': 1.0, + # 'coral_penalty_weight': 0.1, + }, 'fmow': { 'split_scheme': 'official', 'dataset_kwargs': { From dbc1fa0565df9daa15ce6b815b384a29d87be72f Mon Sep 17 00:00:00 2001 From: aikanor Date: Wed, 10 Feb 2021 00:18:56 -0800 Subject: [PATCH 064/244] integration 12/ --- dataset_preprocessing/encode-tfbs/README.md | 2 +- .../encode-tfbs/prep_accessibility.py | 12 ++++++------ .../encode-tfbs/prep_datasets.ipynb | 12 ++++++------ .../encode-tfbs/prep_sequence.py | 4 +++- examples/models/CNN_genome.py | 15 +++------------ sandbox_data.ipynb | 9 ++++++--- wilds/common/metrics/all_metrics.py | 2 +- 7 files changed, 26 insertions(+), 30 deletions(-) diff --git a/dataset_preprocessing/encode-tfbs/README.md b/dataset_preprocessing/encode-tfbs/README.md index 0be5fbd6..616d4cb5 100644 --- a/dataset_preprocessing/encode-tfbs/README.md +++ b/dataset_preprocessing/encode-tfbs/README.md @@ -5,7 +5,7 @@ #### Instructions -1. Download the human genome sequence (hg19 assembly) in FASTA format from http://hgdownload.cse.ucsc.edu/goldenpath/hg19/bigZips/hg19.fa.gz into `SEQUENCE_PATH`. +1. Download the human genome sequence (hg19 assembly) in FASTA format from http://hgdownload.cse.ucsc.edu/goldenpath/hg19/bigZips/hg19.fa.gz and extract it into `SEQUENCE_PATH`. 2. Run `python prep_sequence.py --seq_path SEQUENCE_PATH --output_dir OUTPUT_DIR` to write the fasta file found in `SEQUENCE_PATH` to a numpy array archive in `OUTPUT_DIR`. diff --git a/dataset_preprocessing/encode-tfbs/prep_accessibility.py b/dataset_preprocessing/encode-tfbs/prep_accessibility.py index 31bf872c..141981c0 100644 --- a/dataset_preprocessing/encode-tfbs/prep_accessibility.py +++ b/dataset_preprocessing/encode-tfbs/prep_accessibility.py @@ -1,24 +1,24 @@ +import argparse, time import numpy as np import pyBigWig # Human chromosome names chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] -def generate_accessibility_archives(input_dir, output_dir): +def generate_accessibility_archives(input_dir='dnase_bigwigs', output_dir='codalab_archive'): dnases = {} celltypes = ['A549', 'GM12878', 'H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] - for ctype in celltypes:#glob.glob('dnase_bigwigs/*'): + for ctype in celltypes: itime = time.time() - # ctype = pth.split('/')[1].split('.')[1] bw = pyBigWig.open("{}/DNASE.{}.fc.signal.bigwig".format(input_dir, ctype)) chromsizes = bw.chroms() - print(ctype, time.time() - itime) dn_dict = {} for chrom in chromsizes: #chr_IDs: x = bw.values(chrom, 0, chromsizes[chrom], numpy=True) - dn_dict[chrom] = np.nan_to_num(x).astype(np.float16) # half-precision makes things significantly smaller (less time to load) - print(chrom, time.time() - itime) + # half-precision makes things significantly smaller (less time to load) + dn_dict[chrom] = np.nan_to_num(x).astype(np.float16) + print("{}, {}. Time: {}".format(ctype, chrom, time.time() - itime)) dnases[ctype] = dn_dict for ctype in dnases: diff --git a/dataset_preprocessing/encode-tfbs/prep_datasets.ipynb b/dataset_preprocessing/encode-tfbs/prep_datasets.ipynb index 4b1fdc10..78235fd7 100644 --- a/dataset_preprocessing/encode-tfbs/prep_datasets.ipynb +++ b/dataset_preprocessing/encode-tfbs/prep_datasets.ipynb @@ -257,23 +257,23 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 2", + "display_name": "Python 3", "language": "python", - "name": "python2" + "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", - "version": 2 + "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.13" + "pygments_lexer": "ipython3", + "version": "3.8.5" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/dataset_preprocessing/encode-tfbs/prep_sequence.py b/dataset_preprocessing/encode-tfbs/prep_sequence.py index 7d6ede23..3ead9a27 100644 --- a/dataset_preprocessing/encode-tfbs/prep_sequence.py +++ b/dataset_preprocessing/encode-tfbs/prep_sequence.py @@ -3,6 +3,8 @@ from tqdm import tqdm +# Sequence preprocessing. Code adapted from Jacob Schreiber. + # Human chromosome names chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] @@ -32,7 +34,7 @@ def one_hot_encode(sequence, ignore='N', alphabet=None, dtype='int8', verbose=Fa ohe : numpy.ndarray A binary matrix of shape (alphabet_size, sequence_length) where alphabet_size is the number of unique elements in the sequence and sequence_length is the length of the input sequence. """ - + name = None if verbose in (True, False) else verbose d = verbose is False diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py index 8a658eab..b0743960 100644 --- a/examples/models/CNN_genome.py +++ b/examples/models/CNN_genome.py @@ -6,23 +6,14 @@ class Beagle(nn.Module): """ - Neural net models over genomic sequence. + Neural net models over genomic sequence. Adapted from https://github.com/kundajelab/ChromDragoNN Input: - - sequence_length: int (default 1000) - - Shape: (N, 5, sequence_length, 1) with batch size N. + - s (Tensor): float torch tensor of shape (N, 5, 1000, 1) with batch size N. Output: - prediction (Tensor): float torch tensor of shape (N, ) - - TODO: Finish docstring. """ def __init__(self): - """ - Parameters - ---------- - sequence_length : int - n_genomic_features : int - """ super(Beagle, self).__init__() self.dropout = 0.3 @@ -57,6 +48,6 @@ def forward(self, s): s = F.dropout(F.relu(self.bn4(self.fc1(s))), p=self.dropout, training=self.training) # batch_size x 1000 s = F.dropout(F.relu(self.bn5(self.fc2(s))), p=self.dropout, training=self.training) # batch_size x 1000 - s = self.fc3(s) + prediction = self.fc3(s) return s#, conv_out diff --git a/sandbox_data.ipynb b/sandbox_data.ipynb index 681c5ec2..a348d6ea 100644 --- a/sandbox_data.ipynb +++ b/sandbox_data.ipynb @@ -18,7 +18,7 @@ " - [x] datasets/encodetfbs_dataset.py\n", " - common\n", " - metrics\n", - " - [ ] all_metrics.py\n", + " - [x] all_metrics.py\n", " - data_loaders.py\n", " - grouper.py\n", " - [ ] utils.py ( threshold_at_recall() )" @@ -30,9 +30,12 @@ "source": [ "# TODOs\n", "\n", - "- change evaluation metric\n", + "- [ ] change evaluation/validation metric\n", + " - \n", "\n", - "- change sequence length of model\n", + "- [ ] Citation/license for wilds/datasets/encodetfbs_dataset.py\n", + "\n", + "- (optional) change sequence length of model\n", " - examples/configs/model.py\n", " - examples/models/CNN_genome.py" ] diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py index 6a680467..a491bedd 100644 --- a/wilds/common/metrics/all_metrics.py +++ b/wilds/common/metrics/all_metrics.py @@ -78,7 +78,7 @@ def worst(self, metrics): return minimum(metrics) class AveragePrecision(Metric): - def __init__(self, prediction_fn=logits_to_pred, name=None, average='weighted'): + def __init__(self, prediction_fn=logits_to_pred, name=None, average='macro'): self.prediction_fn = prediction_fn if name is None: name = f'avgprec' From 334d57545d58e3b485e3d22db403eb842d1016db Mon Sep 17 00:00:00 2001 From: aikanor Date: Wed, 10 Feb 2021 01:11:25 -0800 Subject: [PATCH 065/244] integration 13/ --- examples/models/CNN_genome.py | 2 +- sandbox_data.ipynb | 16 +++++++--------- wilds/datasets/encodetfbs_dataset.py | 23 ++++++++++++----------- 3 files changed, 20 insertions(+), 21 deletions(-) diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py index b0743960..cc464ef0 100644 --- a/examples/models/CNN_genome.py +++ b/examples/models/CNN_genome.py @@ -50,4 +50,4 @@ def forward(self, s): prediction = self.fc3(s) - return s#, conv_out + return s #, conv_out diff --git a/sandbox_data.ipynb b/sandbox_data.ipynb index a348d6ea..15e35cc9 100644 --- a/sandbox_data.ipynb +++ b/sandbox_data.ipynb @@ -9,7 +9,7 @@ " - configs\n", " - [x] supported.py\n", " - [x] model.py\n", - " - [ ] datasets.py\n", + " - [x] datasets.py\n", " - models\n", " - [x] CNN_genome.py\n", " - train.py\n", @@ -21,7 +21,7 @@ " - [x] all_metrics.py\n", " - data_loaders.py\n", " - grouper.py\n", - " - [ ] utils.py ( threshold_at_recall() )" + " - [x] utils.py ( threshold_at_recall() )" ] }, { @@ -30,14 +30,12 @@ "source": [ "# TODOs\n", "\n", - "- [ ] change evaluation/validation metric\n", - " - \n", - "\n", - "- [ ] Citation/license for wilds/datasets/encodetfbs_dataset.py\n", - "\n", + "- change evaluation/validation metric\n", + " - [ ] examples/configs/datasets.py\n", + "- Citation/license for wilds/datasets/encodetfbs_dataset.py\n", "- (optional) change sequence length of model\n", - " - examples/configs/model.py\n", - " - examples/models/CNN_genome.py" + " - [ ] examples/configs/model.py\n", + " - [ ] examples/models/CNN_genome.py" ] }, { diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 23f0b1d7..56f8c2f9 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -62,13 +62,14 @@ def __init__(self, root_dir, download, split_scheme): self._dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom] # Read in metadata dataframe from training+validation data - train_chr = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.train.labels.tsv.gz'.format(self._transcription_factor)), sep='\t') - val_chr = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.val.labels.tsv.gz'.format(self._transcription_factor)), sep='\t') - training_df = train_chr[np.isin(train_chr['chr'], self._tr_chrs)] - val_df = val_chr[np.isin(val_chr['chr'], self._te_chrs)] + train_regions_labeled = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.train.labels.tsv.gz'.format(self._transcription_factor)), sep='\t') + val_regions_labeled = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.val.labels.tsv.gz'.format(self._transcription_factor)), sep='\t') + training_df = train_regions_labeled[np.isin(train_regions_labeled['chr'], self._tr_chrs)] + val_df = val_regions_labeled[np.isin(val_regions_labeled['chr'], self._te_chrs)] all_df = pd.concat([training_df, val_df]) - # Filter by start/stop coordinate if needed + # Filter by start/stop coordinate if needed + # (TODO: remove for final version) filter_msk = all_df['start'] >= 0 filter_msk = all_df['start']%1000 == 0 all_df = all_df[filter_msk] @@ -111,18 +112,18 @@ def __init__(self, root_dir, download, split_scheme): 'test': 'Test', 'val-ood': 'Validation (OOD)', } - train_chr_mask = np.isin(self._metadata_df['chr'], self._tr_chrs) - val_chr_mask = np.isin(self._metadata_df['chr'], self._te_chrs) + train_regions_mask = np.isin(self._metadata_df['chr'], self._tr_chrs) + val_regions_mask = np.isin(self._metadata_df['chr'], self._te_chrs) train_celltype_mask = np.isin(self._metadata_df['celltype'], self._train_celltypes) val_celltype_mask = np.isin(self._metadata_df['celltype'], self._val_celltype) test_celltype_mask = np.isin(self._metadata_df['celltype'], self._test_celltype) split_array = -1*np.ones(self._metadata_df.shape[0]).astype(int) - split_array[np.logical_and(train_chr_mask, train_celltype_mask)] = self._split_dict['train'] - split_array[np.logical_and(val_chr_mask, test_celltype_mask)] = self._split_dict['test'] + split_array[np.logical_and(train_regions_mask, train_celltype_mask)] = self._split_dict['train'] + split_array[np.logical_and(val_regions_mask, test_celltype_mask)] = self._split_dict['test'] # Validate using test chr, either using a designated validation cell line ('val-ood') or a training cell line ('val-id') - split_array[np.logical_and(val_chr_mask, val_celltype_mask)] = self._split_dict['val-ood'] - split_array[np.logical_and(val_chr_mask, train_celltype_mask)] = self._split_dict['val-id'] + split_array[np.logical_and(val_regions_mask, val_celltype_mask)] = self._split_dict['val-ood'] + split_array[np.logical_and(val_regions_mask, train_celltype_mask)] = self._split_dict['val-id'] if self._split_scheme=='standard': self._metadata_df['split'] = split_array self._split_array = split_array From c712dc240ea02823b0a0091b918141233f4a7e5c Mon Sep 17 00:00:00 2001 From: aikanor Date: Wed, 10 Feb 2021 04:03:51 -0800 Subject: [PATCH 066/244] integration 14/ --- sandbox_data.ipynb | 1 + wilds/datasets/encodetfbs_dataset.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/sandbox_data.ipynb b/sandbox_data.ipynb index 15e35cc9..c465e0ab 100644 --- a/sandbox_data.ipynb +++ b/sandbox_data.ipynb @@ -32,6 +32,7 @@ "\n", "- change evaluation/validation metric\n", " - [ ] examples/configs/datasets.py\n", + "- Add `RELEASE_v1.0.txt` to codalab archive\n", "- Citation/license for wilds/datasets/encodetfbs_dataset.py\n", "- (optional) change sequence length of model\n", " - [ ] examples/configs/model.py\n", diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 56f8c2f9..dc76a366 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -24,7 +24,7 @@ class EncodeTFBSDataset(WILDSDataset): https://www.synapse.org/#!Synapse:syn6131484 """ - def __init__(self, root_dir, download, split_scheme): + def __init__(self, root_dir='data', download=False, split_scheme='official'): self._dataset_name = 'encode-tfbs' self._version = '1.0' self._download_url = 'https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/' From a0c8c8fc2b101f7a66e018194529d9cf423c8dfe Mon Sep 17 00:00:00 2001 From: aikanor Date: Wed, 10 Feb 2021 08:56:19 -0800 Subject: [PATCH 067/244] integration 14/ --- wilds/datasets/encodetfbs_dataset.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index dc76a366..327ffb3e 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -1,4 +1,4 @@ -import os +import os, time import torch import pandas as pd import numpy as np @@ -25,6 +25,7 @@ class EncodeTFBSDataset(WILDSDataset): """ def __init__(self, root_dir='data', download=False, split_scheme='official'): + itime = time.time() self._dataset_name = 'encode-tfbs' self._version = '1.0' self._download_url = 'https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/' @@ -52,6 +53,7 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): self._seq_bp = {} for chrom in seq_arr: self._seq_bp[chrom] = seq_arr[chrom] + print(chrom, time.time() - itime) self._dnase_allcelltypes = {} for ct in self._all_celltypes: @@ -60,6 +62,7 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): self._dnase_allcelltypes[ct] = {} for chrom in self._seq_bp: self._dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom] + print(ct, time.time() - itime) # Read in metadata dataframe from training+validation data train_regions_labeled = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.train.labels.tsv.gz'.format(self._transcription_factor)), sep='\t') @@ -67,6 +70,7 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): training_df = train_regions_labeled[np.isin(train_regions_labeled['chr'], self._tr_chrs)] val_df = val_regions_labeled[np.isin(val_regions_labeled['chr'], self._te_chrs)] all_df = pd.concat([training_df, val_df]) + print(train_regions_labeled, time.time() - itime) # Filter by start/stop coordinate if needed # (TODO: remove for final version) @@ -150,7 +154,7 @@ def get_input(self, idx): interval_end = this_metadata['stop'] + flank_size dnase_this = _dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end] seq_this = _seq_bp[this_metadata['chr']][interval_start:interval_end] - return np.column_stack([seq_this, dnase_this]) + return torch.tensor(np.column_stack([seq_this, dnase_this])) def eval(self, y_pred, y_true, metadata): return self.standard_group_eval( From ccd6082cc594db36e3a94672f5e8b15a0ff38ffc Mon Sep 17 00:00:00 2001 From: aikanor Date: Wed, 10 Feb 2021 19:45:41 -0800 Subject: [PATCH 068/244] integration 15/ (refactor encodetfbs_dataset) --- examples/sbox_run_expt.ipynb | 2193 ++++++++++++++++++++++++++ wilds/datasets/encodetfbs_dataset.py | 102 +- 2 files changed, 2248 insertions(+), 47 deletions(-) create mode 100644 examples/sbox_run_expt.ipynb diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb new file mode 100644 index 00000000..56a9f6a2 --- /dev/null +++ b/examples/sbox_run_expt.ipynb @@ -0,0 +1,2193 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# run_expt.py contents\n", + "\n", + "## 1) Preamble" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "import os, csv\n", + "import time\n", + "import argparse\n", + "import numpy as np, pandas as pd\n", + "import torch\n", + "import torch.nn as nn\n", + "import torchvision\n", + "import sys\n", + "from collections import defaultdict\n", + "\n", + "from wilds.common.data_loaders import get_train_loader, get_eval_loader\n", + "from wilds.common.grouper import CombinatorialGrouper\n", + "\n", + "from utils import set_seed, Logger, BatchLogger, log_config, ParseKwargs, load, initialize_wandb, log_group_data, parse_bool\n", + "from train import train, evaluate\n", + "from algorithms.initializer import initialize_algorithm\n", + "from transforms import initialize_transform\n", + "from configs.utils import populate_defaults\n", + "import configs.supported as supported" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "''' set default hyperparams in default_hyperparams.py '''\n", + "parser = argparse.ArgumentParser()\n", + "\n", + "# Required arguments\n", + "parser.add_argument('-d', '--dataset', choices=supported.datasets, required=True)\n", + "parser.add_argument('--algorithm', required=True, choices=supported.algorithms)\n", + "parser.add_argument('--root_dir', required=True,\n", + " help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')\n", + "\n", + "# Dataset\n", + "parser.add_argument('--split_scheme', help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')\n", + "parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--download', default=False, type=parse_bool, const=True, nargs='?',\n", + " help='If true, tries to downloads the dataset if it does not exist in root_dir.')\n", + "parser.add_argument('--frac', type=float, default=1.0,\n", + " help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.')\n", + "\n", + "# Loaders\n", + "parser.add_argument('--loader_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--train_loader', choices=['standard', 'group'])\n", + "parser.add_argument('--uniform_over_groups', type=parse_bool, const=True, nargs='?')\n", + "parser.add_argument('--distinct_groups', type=parse_bool, const=True, nargs='?')\n", + "parser.add_argument('--n_groups_per_batch', type=int)\n", + "parser.add_argument('--batch_size', type=int)\n", + "parser.add_argument('--eval_loader', choices=['standard'], default='standard')\n", + "\n", + "# Model\n", + "parser.add_argument('--model', choices=supported.models)\n", + "parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},\n", + " help='keyword arguments for model initialization passed as key1=value1 key2=value2')\n", + "\n", + "# Transforms\n", + "parser.add_argument('--train_transform', choices=supported.transforms)\n", + "parser.add_argument('--eval_transform', choices=supported.transforms)\n", + "parser.add_argument('--target_resolution', nargs='+', type=int, help='target resolution. for example --target_resolution 224 224 for standard resnet.')\n", + "parser.add_argument('--resize_scale', type=float)\n", + "parser.add_argument('--max_token_length', type=int)\n", + "\n", + "# Objective\n", + "parser.add_argument('--loss_function', choices = supported.losses)\n", + "\n", + "# Algorithm\n", + "parser.add_argument('--groupby_fields', nargs='+')\n", + "parser.add_argument('--group_dro_step_size', type=float)\n", + "parser.add_argument('--coral_penalty_weight', type=float)\n", + "parser.add_argument('--irm_lambda', type=float)\n", + "parser.add_argument('--irm_penalty_anneal_iters', type=int)\n", + "parser.add_argument('--algo_log_metric')\n", + "\n", + "# Model selection\n", + "parser.add_argument('--val_metric')\n", + "parser.add_argument('--val_metric_decreasing', type=parse_bool, const=True, nargs='?')\n", + "\n", + "# Optimization\n", + "parser.add_argument('--n_epochs', type=int)\n", + "parser.add_argument('--optimizer', choices=supported.optimizers)\n", + "parser.add_argument('--lr', type=float)\n", + "parser.add_argument('--weight_decay', type=float)\n", + "parser.add_argument('--max_grad_norm', type=float)\n", + "parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "\n", + "# Scheduler\n", + "parser.add_argument('--scheduler', choices=supported.schedulers)\n", + "parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})\n", + "parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')\n", + "parser.add_argument('--scheduler_metric_name')\n", + "\n", + "# Evaluation\n", + "parser.add_argument('--evaluate_all_splits', type=parse_bool, const=True, nargs='?', default=True)\n", + "parser.add_argument('--eval_splits', nargs='+', default=[])\n", + "parser.add_argument('--eval_only', type=parse_bool, const=True, nargs='?', default=False)\n", + "parser.add_argument('--eval_epoch', default=None, type=int)\n", + "\n", + "# Misc\n", + "parser.add_argument('--device', type=int, default=0)\n", + "parser.add_argument('--seed', type=int, default=0)\n", + "parser.add_argument('--log_dir', default='./logs')\n", + "parser.add_argument('--log_every', default=50, type=int)\n", + "parser.add_argument('--save_step', type=int)\n", + "parser.add_argument('--save_best', type=parse_bool, const=True, nargs='?', default=True)\n", + "parser.add_argument('--save_last', type=parse_bool, const=True, nargs='?', default=True)\n", + "parser.add_argument('--no_group_logging', type=parse_bool, const=True, nargs='?')\n", + "parser.add_argument('--use_wandb', type=parse_bool, const=True, nargs='?', default=False)\n", + "parser.add_argument('--progress_bar', type=parse_bool, const=True, nargs='?', default=False)\n", + "parser.add_argument('--resume', type=parse_bool, const=True, nargs='?', default=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "argstr_camelyon = \"--dataset camelyon17 --algorithm ERM --root_dir data\"\n", + "config_camelyon = parser.parse_args(argstr_camelyon.split())\n", + "config_camelyon = populate_defaults(config_camelyon)\n", + "\n", + "argstr_encode = \"--dataset encode-tfbs --algorithm ERM --root_dir data\"\n", + "config_encode = parser.parse_args(argstr_encode.split())\n", + "config_encode = populate_defaults(config_encode)\n", + "\n", + "config = config_camelyon\n", + "#config = config_encode" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset: camelyon17\n", + "Algorithm: ERM\n", + "Root dir: data\n", + "Split scheme: official\n", + "Dataset kwargs: {}\n", + "Download: False\n", + "Frac: 1.0\n", + "Loader kwargs: {'num_workers': 4, 'pin_memory': True}\n", + "Train loader: standard\n", + "Uniform over groups: False\n", + "Distinct groups: None\n", + "N groups per batch: 2\n", + "Batch size: 32\n", + "Eval loader: standard\n", + "Model: densenet121\n", + "Model kwargs: {'pretrained': False}\n", + "Train transform: image_base\n", + "Eval transform: image_base\n", + "Target resolution: (224, 224)\n", + "Resize scale: None\n", + "Max token length: None\n", + "Loss function: cross_entropy\n", + "Groupby fields: ['hospital']\n", + "Group dro step size: None\n", + "Coral penalty weight: 0.1\n", + "Irm lambda: 1.0\n", + "Irm penalty anneal iters: None\n", + "Algo log metric: accuracy\n", + "Val metric: acc_avg\n", + "Val metric decreasing: False\n", + "N epochs: 5\n", + "Optimizer: SGD\n", + "Lr: 0.001\n", + "Weight decay: 0.01\n", + "Max grad norm: None\n", + "Optimizer kwargs: {'momentum': 0.9}\n", + "Scheduler: None\n", + "Scheduler kwargs: {}\n", + "Scheduler metric split: val\n", + "Scheduler metric name: None\n", + "Evaluate all splits: True\n", + "Eval splits: []\n", + "Eval only: False\n", + "Eval epoch: None\n", + "Device: cuda:0\n", + "Seed: 0\n", + "Log dir: ./logs\n", + "Log every: 50\n", + "Save step: None\n", + "Save best: True\n", + "Save last: True\n", + "No group logging: False\n", + "Use wandb: False\n", + "Progress bar: False\n", + "Resume: False\n", + "\n" + ] + } + ], + "source": [ + "# set device\n", + "config.device = torch.device(\"cuda:\" + str(config.device)) if torch.cuda.is_available() else torch.device(\"cpu\")\n", + "\n", + "## Initialize logs\n", + "if os.path.exists(config.log_dir) and config.resume:\n", + " resume=True\n", + " mode='a'\n", + "elif os.path.exists(config.log_dir) and config.eval_only:\n", + " resume=False\n", + " mode='a'\n", + "else:\n", + " resume=False\n", + " mode='w'\n", + "\n", + "if not os.path.exists(config.log_dir):\n", + " os.makedirs(config.log_dir)\n", + "logger = Logger(os.path.join(config.log_dir, 'log.txt'), mode)\n", + "\n", + "# Record config\n", + "log_config(config, logger)\n", + "\n", + "# Set random seed\n", + "set_seed(config.seed)\n", + "\n", + "# Data\n", + "full_dataset = supported.datasets[config.dataset](\n", + " root_dir=config.root_dir,\n", + " download=config.download,\n", + " split_scheme=config.split_scheme,\n", + " **config.dataset_kwargs)\n", + "\n", + "# To implement data augmentation (i.e., have different transforms\n", + "# at training time vs. test time), modify these two lines:\n", + "train_transform = initialize_transform(\n", + " transform_name=config.train_transform,\n", + " config=config,\n", + " dataset=full_dataset)\n", + "eval_transform = initialize_transform(\n", + " transform_name=config.eval_transform,\n", + " config=config,\n", + " dataset=full_dataset)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2) Initialize dataset object" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "chr2 3.817250967025757\n", + "chr9 6.033524990081787\n", + "chr11 8.150986433029175\n", + "chr1 12.036555290222168\n", + "chr8 14.306443929672241\n", + "chr21 15.043241739273071\n", + "H1-hESC 21.61008930206299\n", + "HCT116 28.000329971313477\n", + "HeLa-S3 34.6184778213501\n", + "HepG2 41.089255809783936\n", + "K562 47.70136523246765\n", + "A549 54.22390341758728\n", + "GM12878 60.65142226219177\n", + " chr start stop A549 GM12878 H1-hESC HCT116 HeLa-S3 \\\n", + "0 chr10 600 800 U U U U U \n", + "1 chr10 650 850 U U U U U \n", + "2 chr10 700 900 U U U U U \n", + "3 chr10 750 950 U U U U U \n", + "4 chr10 800 1000 U U U U U \n", + "... ... ... ... ... ... ... ... ... \n", + "51676731 chrX 155269750 155269950 U U U U U \n", + "51676732 chrX 155269800 155270000 U U U U U \n", + "51676733 chrX 155269850 155270050 U U U U U \n", + "51676734 chrX 155269900 155270100 U U U U U \n", + "51676735 chrX 155269950 155270150 U U U U U \n", + "\n", + " HepG2 K562 \n", + "0 U U \n", + "1 U U \n", + "2 U U \n", + "3 U U \n", + "4 U U \n", + "... ... ... \n", + "51676731 U U \n", + "51676732 U U \n", + "51676733 U U \n", + "51676734 U U \n", + "51676735 U U \n", + "\n", + "[51676736 rows x 10 columns] 130.07371044158936\n" + ] + } + ], + "source": [ + "import os, time\n", + "import torch\n", + "import pandas as pd\n", + "import numpy as np\n", + "from wilds.datasets.wilds_dataset import WILDSDataset\n", + "from wilds.common.grouper import CombinatorialGrouper\n", + "from wilds.common.metrics.all_metrics import Accuracy\n", + "\n", + "root_dir='data'\n", + "download=False\n", + "split_scheme='official'\n", + "\n", + "itime = time.time()\n", + "_dataset_name = 'encode-tfbs'\n", + "_version = '1.0'\n", + "_download_url = 'https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/'\n", + "_data_dir = 'data/encode-tfbs_v1.0'\n", + "_y_size = 1\n", + "_n_classes = 2\n", + "\n", + "# _train_chroms = ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX']\n", + "_train_chroms = ['chr2', 'chr9', 'chr11']\n", + "_test_chroms = ['chr1', 'chr8', 'chr21']\n", + "_transcription_factor = 'MAX'\n", + "_train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", + "_val_celltype = ['A549']\n", + "_test_celltype = ['GM12878']\n", + "_all_chroms = _train_chroms + _test_chroms\n", + "_all_celltypes = _train_celltypes + _val_celltype + _test_celltype\n", + "\n", + "_metadata_map = {}\n", + "_metadata_map['chr'] = _all_chroms #['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']\n", + "_metadata_map['celltype'] = _all_celltypes\n", + "\n", + "# Get the splits\n", + "if split_scheme=='official':\n", + " split_scheme = 'standard'\n", + "\n", + "_split_scheme = split_scheme\n", + "_split_dict = {\n", + " 'train': 0,\n", + " 'id_val': 1,\n", + " 'test': 2,\n", + " 'val': 3\n", + "}\n", + "_split_names = {\n", + " 'train': 'Train',\n", + " 'id_val': 'Validation (ID)',\n", + " 'test': 'Test',\n", + " 'val': 'Validation (OOD)',\n", + "}\n", + "\n", + "# Load sequence and DNase features\n", + "sequence_filename = os.path.join(_data_dir, 'sequence.npz')\n", + "seq_arr = np.load(sequence_filename)\n", + "_seq_bp = {}\n", + "for chrom in _all_chroms: #seq_arr:\n", + " _seq_bp[chrom] = seq_arr[chrom]\n", + " print(chrom, time.time() - itime)\n", + "\n", + "_dnase_allcelltypes = {}\n", + "for ct in _all_celltypes:\n", + " dnase_filename = os.path.join(_data_dir, '{}_dnase.npz'.format(ct))\n", + " dnase_npz_contents = np.load(dnase_filename)\n", + " _dnase_allcelltypes[ct] = {}\n", + " for chrom in _all_chroms: #_seq_bp:\n", + " _dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom]\n", + " print(ct, time.time() - itime)\n", + "\n", + "# Read in metadata dataframe from training+validation data\n", + "train_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.train.labels.tsv.gz'.format(_transcription_factor)), sep='\\t')\n", + "val_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.val.labels.tsv.gz'.format(_transcription_factor)), sep='\\t')\n", + "training_df = train_regions_labeled[np.isin(train_regions_labeled['chr'], _train_chroms)]\n", + "val_df = val_regions_labeled[np.isin(val_regions_labeled['chr'], _test_chroms)]\n", + "all_df = pd.concat([training_df, val_df])\n", + "\n", + "# Filter by start/stop coordinate if needed (TODO: remove for final version)\n", + "filter_msk = all_df['start'] >= 0\n", + "filter_msk = all_df['start']%1000 == 0\n", + "all_df = all_df[filter_msk]\n", + "\n", + "pd_list = []\n", + "for ct in _all_celltypes:\n", + " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", + " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", + " tc_chr.insert(len(tc_chr.columns), 'celltype', ct)\n", + " pd_list.append(tc_chr)\n", + "metadata_df = pd.concat(pd_list)" + ] + }, + { + "cell_type": "code", + "execution_count": 131, + "metadata": {}, + "outputs": [], + "source": [ + "# Get the y values, and remove ambiguous labels by default.\n", + "y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", + "non_ambig_mask = (y_array != -1)\n", + "metadata_df['y'] = y_array\n", + "_metadata_df = metadata_df[non_ambig_mask]\n", + "\n", + "train_regions_mask = np.isin(_metadata_df['chr'], _train_chroms)\n", + "val_regions_mask = np.isin(_metadata_df['chr'], _test_chroms)\n", + "train_celltype_mask = np.isin(_metadata_df['celltype'], _train_celltypes)\n", + "val_celltype_mask = np.isin(_metadata_df['celltype'], _val_celltype)\n", + "test_celltype_mask = np.isin(_metadata_df['celltype'], _test_celltype)\n", + "\n", + "split_array = -1*np.ones(_metadata_df.shape[0]).astype(int)\n", + "split_array[np.logical_and(train_regions_mask, train_celltype_mask)] = _split_dict['train']\n", + "split_array[np.logical_and(val_regions_mask, test_celltype_mask)] = _split_dict['test']\n", + "# Validate using test chr, either using a designated validation cell line ('val') or a training cell line ('id_val')\n", + "split_array[np.logical_and(val_regions_mask, val_celltype_mask)] = _split_dict['val']\n", + "split_array[np.logical_and(val_regions_mask, train_celltype_mask)] = _split_dict['id_val']\n", + "\n", + "if _split_scheme=='standard':\n", + " _metadata_df.insert(len(_metadata_df.columns), 'split', split_array)\n", + "else:\n", + " raise ValueError(f'Split scheme {_split_scheme} not recognized')\n", + "\n", + "_metadata_df = _metadata_df[_metadata_df['split'] != -1]\n", + "_split_array = _metadata_df['split'].values\n", + "\n", + "chr_ints = _metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(_metadata_map['chr'])] )).values\n", + "celltype_ints = _metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(_metadata_map['celltype'])] )).values\n", + "_y_array = torch.LongTensor(np.array(_metadata_df['y']))\n", + "\n", + "_metadata_array = torch.stack(\n", + " (torch.LongTensor(chr_ints), \n", + " torch.LongTensor(celltype_ints), \n", + " _y_array),\n", + " dim=1)\n", + "_metadata_fields = ['chr', 'celltype', 'y']\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Initialize dataset object" + ] + }, + { + "cell_type": "code", + "execution_count": 138, + "metadata": {}, + "outputs": [], + "source": [ + "import os, time\n", + "import torch\n", + "import pandas as pd\n", + "import numpy as np\n", + "from wilds.datasets.wilds_dataset import WILDSDataset\n", + "from wilds.common.grouper import CombinatorialGrouper\n", + "from wilds.common.metrics.all_metrics import Accuracy\n", + "\n", + "class EncodeTFBSDataset(WILDSDataset):\n", + " \"\"\"\n", + " ENCODE-DREAM-wilds dataset of transcription factor binding sites. \n", + " This is a subset of the dataset from the ENCODE-DREAM in vivo Transcription Factor Binding Site Prediction Challenge. \n", + " \n", + " Input (x):\n", + " 1000-base-pair regions of sequence with a quantified chromatin accessibility readout.\n", + "\n", + " Label (y):\n", + " y is binary. It is 1 if the central 200bp region is bound by the transcription factor MAX, and 0 otherwise.\n", + "\n", + " Metadata:\n", + " Each sequence is annotated with the celltype of origin (a string) and the chromosome of origin (a string).\n", + " \n", + " Website:\n", + " https://www.synapse.org/#!Synapse:syn6131484\n", + " \"\"\"\n", + "\n", + " def __init__(self, root_dir='data', download=False, split_scheme='official'):\n", + " itime = time.time()\n", + " self._dataset_name = 'encode-tfbs'\n", + " self._version = '1.0'\n", + " self._download_url = 'https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/'\n", + " self._data_dir = self.initialize_data_dir(root_dir, download)\n", + " self._y_size = 1\n", + " self._n_classes = 2\n", + " \n", + " # self._train_chroms = ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX']\n", + " self._train_chroms = ['chr2', 'chr9', 'chr11']\n", + " self._test_chroms = ['chr1', 'chr8', 'chr21']\n", + " self._transcription_factor = 'MAX'\n", + " self._train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", + " self._val_celltype = ['A549']\n", + " self._test_celltype = ['GM12878']\n", + " self._all_chroms = self._train_chroms + self._test_chroms\n", + " self._all_celltypes = self._train_celltypes + self._val_celltype + self._test_celltype\n", + " \n", + " self._metadata_map = {}\n", + " self._metadata_map['chr'] = self._all_chroms #['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']\n", + " self._metadata_map['celltype'] = self._all_celltypes\n", + " \n", + " # Get the splits\n", + " if split_scheme=='official':\n", + " split_scheme = 'standard'\n", + " \n", + " self._split_scheme = split_scheme\n", + " self._split_dict = {\n", + " 'train': 0,\n", + " 'id_val': 1,\n", + " 'test': 2,\n", + " 'val': 3\n", + " }\n", + " self._split_names = {\n", + " 'train': 'Train',\n", + " 'id_val': 'Validation (ID)',\n", + " 'test': 'Test',\n", + " 'val': 'Validation (OOD)',\n", + " }\n", + " \n", + " # Load sequence and DNase features\n", + " sequence_filename = os.path.join(self._data_dir, 'sequence.npz')\n", + " seq_arr = np.load(sequence_filename)\n", + " self._seq_bp = {}\n", + " for chrom in self._all_chroms: #seq_arr:\n", + " self._seq_bp[chrom] = seq_arr[chrom]\n", + " print(chrom, time.time() - itime)\n", + " \n", + " self._dnase_allcelltypes = {}\n", + " for ct in self._all_celltypes:\n", + " dnase_filename = os.path.join(self._data_dir, '{}_dnase.npz'.format(ct))\n", + " dnase_npz_contents = np.load(dnase_filename)\n", + " self._dnase_allcelltypes[ct] = {}\n", + " for chrom in self._all_chroms: #self._seq_bp:\n", + " self._dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom]\n", + " print(ct, time.time() - itime)\n", + " \n", + " # Read in metadata dataframe from training+validation data\n", + " train_regions_labeled = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.train.labels.tsv.gz'.format(self._transcription_factor)), sep='\\t')\n", + " val_regions_labeled = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.val.labels.tsv.gz'.format(self._transcription_factor)), sep='\\t')\n", + " training_df = train_regions_labeled[np.isin(train_regions_labeled['chr'], self._train_chroms)]\n", + " val_df = val_regions_labeled[np.isin(val_regions_labeled['chr'], self._test_chroms)]\n", + " all_df = pd.concat([training_df, val_df])\n", + " \n", + " # Filter by start/stop coordinate if needed (TODO: remove for final version)\n", + " filter_msk = all_df['start'] >= 0\n", + " filter_msk = all_df['start']%1000 == 0\n", + " all_df = all_df[filter_msk]\n", + " \n", + " pd_list = []\n", + " for ct in self._all_celltypes:\n", + " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", + " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", + " tc_chr.insert(len(tc_chr.columns), 'celltype', ct)\n", + " pd_list.append(tc_chr)\n", + " metadata_df = pd.concat(pd_list)\n", + " \n", + " # Get the y values, and remove ambiguous labels by default.\n", + " y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", + " non_ambig_mask = (y_array != -1)\n", + " metadata_df['y'] = y_array\n", + " self._metadata_df = metadata_df[non_ambig_mask]\n", + " \n", + " train_regions_mask = np.isin(self._metadata_df['chr'], self._train_chroms)\n", + " val_regions_mask = np.isin(self._metadata_df['chr'], self._test_chroms)\n", + " train_celltype_mask = np.isin(self._metadata_df['celltype'], self._train_celltypes)\n", + " val_celltype_mask = np.isin(self._metadata_df['celltype'], self._val_celltype)\n", + " test_celltype_mask = np.isin(self._metadata_df['celltype'], self._test_celltype)\n", + " \n", + " split_array = -1*np.ones(self._metadata_df.shape[0]).astype(int)\n", + " split_array[np.logical_and(train_regions_mask, train_celltype_mask)] = self._split_dict['train']\n", + " split_array[np.logical_and(val_regions_mask, test_celltype_mask)] = self._split_dict['test']\n", + " # Validate using test chr, either using a designated validation cell line ('val') or a training cell line ('id_val')\n", + " split_array[np.logical_and(val_regions_mask, val_celltype_mask)] = self._split_dict['val']\n", + " split_array[np.logical_and(val_regions_mask, train_celltype_mask)] = self._split_dict['id_val']\n", + " \n", + " if self._split_scheme=='standard':\n", + " self._metadata_df.insert(len(self._metadata_df.columns), 'split', split_array)\n", + " else:\n", + " raise ValueError(f'Split scheme {self._split_scheme} not recognized')\n", + " \n", + " self._metadata_df = self._metadata_df[self._metadata_df['split'] != -1]\n", + " self._split_array = self._metadata_df['split'].values\n", + " \n", + " chr_ints = self._metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['chr'])] )).values\n", + " celltype_ints = self._metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['celltype'])] )).values\n", + " self._y_array = torch.LongTensor(np.array(self._metadata_df['y']))\n", + " \n", + " self._metadata_array = torch.stack(\n", + " (torch.LongTensor(chr_ints), \n", + " torch.LongTensor(celltype_ints), \n", + " self._y_array),\n", + " dim=1)\n", + " self._metadata_fields = ['chr', 'celltype', 'y']\n", + " \n", + " self._eval_grouper = CombinatorialGrouper(\n", + " dataset=self,\n", + " groupby_fields=['celltype'])\n", + " \n", + " self._metric = Accuracy()\n", + " \n", + " super().__init__(root_dir, download, split_scheme)\n", + "\n", + " def get_input(self, idx):\n", + " \"\"\"\n", + " Returns x for a given idx.\n", + " Computes this from: \n", + " (1) sequence features in self._seq_bp\n", + " (2) DNase features in self._dnase_allcelltypes\n", + " (3) Metadata for the index (location along the genome with 200bp window width)\n", + " \"\"\"\n", + " this_metadata = self._metadata_df.iloc[idx, :]\n", + " flank_size = 400\n", + " interval_start = this_metadata['start'] - flank_size\n", + " interval_end = this_metadata['stop'] + flank_size\n", + " dnase_this = self._dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end]\n", + " seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end]\n", + " return torch.tensor(np.column_stack([seq_this, dnase_this]))\n", + "\n", + " def eval(self, y_pred, y_true, metadata):\n", + " return self.standard_group_eval(\n", + " self._metric,\n", + " self._eval_grouper,\n", + " y_pred, y_true, metadata)" + ] + }, + { + "cell_type": "code", + "execution_count": 139, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "chr2 3.836134910583496\n", + "chr9 6.034452438354492\n", + "chr11 8.16244888305664\n", + "chr1 12.12421178817749\n", + "chr8 14.44963550567627\n", + "chr21 15.212148189544678\n", + "H1-hESC 21.892271518707275\n", + "HCT116 28.37229895591736\n", + "HeLa-S3 35.18828296661377\n", + "HepG2 41.83891773223877\n", + "K562 48.590251445770264\n", + "A549 55.3311812877655\n", + "GM12878 61.93817687034607\n" + ] + } + ], + "source": [ + "full_dataset_encode = EncodeTFBSDataset(\n", + " root_dir=config.root_dir,\n", + " download=config.download,\n", + " split_scheme=config.split_scheme,\n", + " **config.dataset_kwargs)" + ] + }, + { + "cell_type": "code", + "execution_count": 140, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "<__main__.EncodeTFBSDataset at 0x7fe6b69d33a0>" + ] + }, + "execution_count": 140, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "full_dataset_encode" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([0, 1]), array([227977, 227977]))" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([1, 1, 1, ..., 0, 0, 0]) torch.Size([455954])\n" + ] + }, + { + "data": { + "text/plain": [ + "['patches/patient_004_node_4/patch_patient_004_node_4_x_3328_y_21792.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3200_y_22272.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3168_y_22272.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3328_y_21760.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3232_y_22240.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3168_y_22240.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3136_y_22208.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2656_y_18880.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3136_y_22240.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3296_y_21856.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3296_y_21792.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3360_y_21824.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3360_y_21760.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3296_y_21824.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3328_y_21824.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2688_y_18912.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3168_y_22176.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2688_y_18816.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3200_y_22176.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3168_y_22208.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2688_y_18880.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3296_y_21760.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2656_y_18848.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3136_y_22272.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3264_y_21856.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3264_y_21824.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2688_y_18848.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3264_y_21792.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2656_y_18944.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3200_y_22208.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3232_y_22208.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3200_y_22240.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2688_y_18944.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3360_y_21792.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2656_y_18912.png',\n", + " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2656_y_18816.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12960_y_36128.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_35968.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36512.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36064.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13312_y_36320.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35968.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34560.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36384.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36192.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13120_y_35936.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36480.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_35680.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35648.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36032.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36416.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36160.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12160_y_34752.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36288.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36160.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_35744.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_35840.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36320.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12896_y_35648.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35904.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36512.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13120_y_35904.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35872.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36192.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36224.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36416.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_35872.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13152_y_35872.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_35904.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36224.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36192.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_35968.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_35680.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_35648.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36096.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_35616.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36224.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35776.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36064.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35936.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_35808.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36224.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36320.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36032.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36128.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34752.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36288.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35904.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36160.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_36192.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13184_y_35968.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_35808.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13184_y_35936.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36192.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34560.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34528.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13120_y_36064.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36032.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34784.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36288.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36224.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36256.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35584.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_35616.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34720.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36352.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36160.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34624.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36416.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36384.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_36352.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_35712.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12960_y_36096.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12960_y_36352.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36384.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36352.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36448.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36448.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_36288.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36128.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12064_y_34560.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_35808.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34592.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12960_y_36448.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_35616.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36224.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36000.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_35840.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36096.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36416.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36352.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12064_y_34720.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_36064.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_35648.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36352.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_36096.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13120_y_35872.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36192.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36256.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_35840.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36128.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_35872.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_36064.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36032.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36320.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_36224.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_35968.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_35808.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36064.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12896_y_36224.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_35744.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36128.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36352.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36096.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13312_y_36224.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35872.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34624.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_35872.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36384.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_35936.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_35936.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36192.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34720.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_36096.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36352.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36128.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_35488.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36512.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13376_y_36192.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36096.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36160.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13376_y_36128.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36064.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36256.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_36128.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36480.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36480.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36480.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36544.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_36032.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36384.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36096.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36320.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36224.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12896_y_36256.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13344_y_36288.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34656.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36544.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36064.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_35936.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35776.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36064.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13120_y_36000.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35744.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12960_y_36416.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36224.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36192.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_35840.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_35872.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36160.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12160_y_34720.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36192.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_35872.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36096.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36480.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34592.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35936.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36096.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13376_y_36160.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36288.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_35872.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12064_y_34688.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36160.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12896_y_35616.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36096.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35616.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_35840.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36000.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36032.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_35840.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35808.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_35904.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36032.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34752.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36032.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13120_y_35968.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12064_y_34624.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_35680.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_35808.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12160_y_34688.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_35968.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36000.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36416.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36288.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36352.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36096.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12960_y_36064.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34688.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13152_y_35936.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13152_y_36000.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36224.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13344_y_36192.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36096.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35744.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12064_y_34592.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34688.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_35840.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36320.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_35584.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34656.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36416.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_35680.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36000.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_35936.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34784.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_35936.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35648.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_35712.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36416.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36256.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36416.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36192.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_36320.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12896_y_36288.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36064.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36128.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36128.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36096.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36128.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36384.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12160_y_34656.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36448.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36352.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_36000.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36320.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_35840.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12064_y_34528.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_35936.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36416.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36320.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_35872.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36256.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36544.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13120_y_36032.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36224.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_35904.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36064.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36160.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36256.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_35488.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36448.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36288.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_35936.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36096.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36256.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36128.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13280_y_36288.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_35968.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35680.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_36160.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_35648.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12960_y_36384.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36288.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36128.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36512.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36160.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13344_y_36224.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_36128.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36128.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36480.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36224.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_36256.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13184_y_35904.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36256.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_35776.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36160.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_35840.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_35904.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35968.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_35968.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13344_y_36160.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_35968.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_35872.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36032.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13312_y_36288.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_35776.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35808.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13184_y_35872.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_35904.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13152_y_35968.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_35904.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36256.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36288.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36448.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36000.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36448.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_36000.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_36032.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_35872.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36320.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36384.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13152_y_35904.png',\n", + " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36256.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16416.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16544.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16384_y_24352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_16544.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16288_y_24448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17216.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_15968.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_17056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18400_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16032.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16544_y_24768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_16288.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17024_y_24512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_16224.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15488.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_17440.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17536_y_15552.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16384.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_15904.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16320.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17472.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16864.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16192.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17632_y_15104.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_17024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16704.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_15648.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16416.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16832.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17216.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_15616.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16288.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_15680.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16896.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16320_y_25056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17344_y_28160.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16704.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16928.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15520.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16480_y_24992.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17536.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_17216.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_15872.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17088_y_24448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_16192.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16160.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17088_y_24512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19360_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_17024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16160.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_16352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_17088.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16160.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16256.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16896.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16544.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16416.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16320.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_17152.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_17248.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_17024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16320_y_24448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_16800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16768_y_25152.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16736_y_25152.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16704_y_24640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16064.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17600_y_15488.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_17376.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16864_y_24288.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17120_y_25088.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16544_y_24800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18304_y_16064.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17056_y_24512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16096.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16384.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17280_y_24768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16288.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_17056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16768_y_24928.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16320.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_17024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17184.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16288_y_25024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_15648.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16320.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16352_y_25024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17184.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_17056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16544.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16288.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_15712.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_17152.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_16480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17216.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16544.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_15744.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16096.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16256_y_24448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16480_y_24832.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_25024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16832.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_16448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_15680.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16480_y_24800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16992_y_24448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_16768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16224.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_17280.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16832_y_24288.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_17408.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16192.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16960.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16000.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17088.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17248_y_24768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16832.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17504.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17632_y_15520.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16832.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16992.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17248.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17504.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16256_y_24480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16000.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16320_y_24608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16416.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17280_y_28064.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16032.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17536_y_25120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17152.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16192.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17408_y_24896.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19296_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16096.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16992.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16704.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16864.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_17056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_17120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16352_y_24512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_15648.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16896.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_16704.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_17088.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16256.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19360_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16768_y_25184.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16288.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17184.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16160_y_24864.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16032.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16416.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16640_y_25440.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17280_y_28192.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16800_y_25120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_17344.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16224.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16384.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16832.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16416.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18400_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17600_y_15392.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16320_y_24480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_15968.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16640_y_24576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16992_y_24960.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16672_y_24512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18336_y_16064.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16096.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16128_y_24896.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17440.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_17248.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16544_y_25024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16064.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_17440.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17536.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17440_y_24928.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_17216.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16864_y_24352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16224.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_17184.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16064.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16000.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_17344.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_15680.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16672_y_24928.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17536_y_15520.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16704_y_24608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16544.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_16416.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17312_y_28128.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17312.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_17120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16448_y_24768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16128.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16640_y_24640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16480_y_25024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17152.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_17120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_17280.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_15584.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_15648.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17408_y_24928.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16832.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17344_y_28128.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_17184.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_17440.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16864.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18400_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_16960.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_15968.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_15680.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_17024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_15616.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_16192.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18336_y_16160.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16864.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17024_y_24960.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16320_y_24640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15328.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17408.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_15712.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16736_y_24608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16704.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16928.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17088.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16704.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16960.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16960.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16096.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16448_y_25024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16384.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16832.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17376_y_24896.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_15648.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16064.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_15584.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16192.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16544.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16864.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16768_y_25120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_17248.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16384_y_24640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_17184.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_17088.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_17376.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_17088.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16544.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16064.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16512_y_24352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17568_y_15488.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_16128.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16992.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16704.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15456.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16480_y_25056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_17088.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16256.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19296_y_16832.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16928.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17344.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_17120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16352_y_24608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15104.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16320.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16800_y_24928.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16672_y_24960.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17696_y_15488.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_15744.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_17344.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16704_y_24960.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17184_y_25120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16064.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16992_y_24480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_15776.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_15744.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16096.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19328_y_16800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17856_y_15296.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_15712.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_15904.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16384_y_24512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16992.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16096.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16512_y_24800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_17216.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_17056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_25120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17280.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16416_y_24320.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_17376.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16576_y_24768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_17440.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17312_y_28032.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16992.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17568_y_25216.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16320.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_17120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16256.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_17024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17056_y_24480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16832_y_24864.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16704.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16288_y_24416.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16704.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16192.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16256.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16960.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17472.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16320.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_15776.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_17440.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16192.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_17216.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16544.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16064.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_17376.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17152_y_25088.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16032.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16544.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16224.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17312.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_17056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16384.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_16416.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16416.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16576_y_24608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17248.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16896.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17248_y_24800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16928.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16896.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16416.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16000.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16640_y_24384.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16416.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17632_y_15072.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_17408.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17600_y_15136.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_15936.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17600_y_15072.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16000.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16000.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16864_y_24320.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16800_y_25152.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16896_y_24576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_17440.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16096.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16384.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16384.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15200.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17120_y_25152.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17568_y_15392.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16416_y_24352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16512_y_24320.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16096.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16480_y_24352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16960.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16960.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_17248.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16576_y_25024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16032.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16704.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_15648.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16448_y_24384.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_25088.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17568_y_25152.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17216.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18336_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16928.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15424.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16992.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16288.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15072.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16992.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_17280.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18400_y_16352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17696_y_15232.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_17152.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16288.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15136.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_15712.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_15712.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16704.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_15680.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19296_y_16544.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17248.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16544_y_25056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_16832.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16288_y_24480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16384.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_17248.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16128.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_17440.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17696_y_15520.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16512_y_25024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16064.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_16128.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_17184.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16416.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17280_y_24800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_16256.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_16096.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19328_y_16832.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16384_y_24320.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17600_y_25216.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_17248.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17280.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16960.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16640_y_25408.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16736_y_24640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16288_y_25056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17152.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16224.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_17152.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16480.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_17024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16928.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16736_y_24576.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16640_y_25056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16512_y_24768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_15680.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16896_y_24320.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16864.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16896.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_16448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17056_y_24448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16704.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16160.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16672_y_25088.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16864.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17600_y_15360.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16032.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16704.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_15936.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_17088.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17632_y_15200.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_15584.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19328_y_16768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16448.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16416.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16288_y_25088.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16480_y_24768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16832.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17408.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16704.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_16512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_17312.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16288_y_24512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_17408.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_17024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16640_y_24608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16128.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16672_y_24640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17152_y_25120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17024_y_24928.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17568_y_25120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16384.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16512.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_17184.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16960.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24384.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16160.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17632_y_15488.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_15648.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_17056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_15744.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16320_y_25024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16736_y_24672.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17632_y_15424.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16800.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_16064.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24992.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_17120.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19296_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16896.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18304_y_16032.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17632_y_15136.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_17056.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16128.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16032.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_17408.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16544.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_17024.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16736.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17568_y_15360.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16640.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16416.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_17280.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17088.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16544.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_15968.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16288.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_17440.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16672_y_24608.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16448_y_24352.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18336_y_16032.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16384.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19360_y_16768.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_17184.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16832.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16992.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_17152.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16160.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17504.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_15936.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24960.png',\n", + " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16672.png',\n", + " ...]" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "print(full_dataset._y_array, full_dataset._y_array.shape)\n", + "print(np.unique(full_dataset.y_array.numpy(), return_counts=True))\n", + "print(np.unique(full_dataset._metadata_df['split'], return_counts=True))\n", + "\n", + "#full_dataset._input_array" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# pd.read_csv(os.path.join('data/camelyon17_v1.0/metadata.csv'), index_col=0, dtype={'patient': 'str'})" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "import copy\n", + "full_dataset_camelyon17 = copy.deepcopy(full_dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "image_base None\n" + ] + } + ], + "source": [ + "supported.datasets[config_encode.dataset]\n", + "print(config_camelyon.train_transform, config_encode.train_transform)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_grouper = CombinatorialGrouper(\n", + " dataset=full_dataset,\n", + " groupby_fields=config.groupby_fields)\n", + "\n", + "datasets = defaultdict(dict)\n", + "for split in full_dataset.split_dict.keys():\n", + " if split=='train':\n", + " transform = train_transform\n", + " verbose = True\n", + " elif split == 'val':\n", + " transform = eval_transform\n", + " verbose = True\n", + " else:\n", + " transform = eval_transform\n", + " verbose = False\n", + " # Get subset\n", + " datasets[split]['dataset'] = full_dataset.get_subset(\n", + " split,\n", + " frac=config.frac,\n", + " transform=transform)\n", + "\n", + " if split == 'train':\n", + " datasets[split]['loader'] = get_train_loader(\n", + " loader=config.train_loader,\n", + " dataset=datasets[split]['dataset'],\n", + " batch_size=config.batch_size,\n", + " uniform_over_groups=config.uniform_over_groups,\n", + " grouper=train_grouper,\n", + " distinct_groups=config.distinct_groups,\n", + " n_groups_per_batch=config.n_groups_per_batch,\n", + " **config.loader_kwargs)\n", + " else:\n", + " datasets[split]['loader'] = get_eval_loader(\n", + " loader=config.eval_loader,\n", + " dataset=datasets[split]['dataset'],\n", + " grouper=train_grouper,\n", + " batch_size=config.batch_size,\n", + " **config.loader_kwargs)\n", + "\n", + " # Set fields\n", + " datasets[split]['split'] = split\n", + " datasets[split]['name'] = full_dataset.split_names[split]\n", + " datasets[split]['verbose'] = verbose\n", + " # Loggers\n", + " # Loggers\n", + " datasets[split]['eval_logger'] = BatchLogger(\n", + " os.path.join(config.log_dir, f'{split}_eval.csv'), mode=mode, use_wandb=(config.use_wandb and verbose))\n", + " datasets[split]['algo_logger'] = BatchLogger(\n", + " os.path.join(config.log_dir, f'{split}_algo.csv'), mode=mode, use_wandb=(config.use_wandb and verbose))\n", + "\n", + " if config.use_wandb:\n", + " initialize_wandb(config)\n", + "\n", + "# Logging dataset info\n", + "if config.no_group_logging and full_dataset.is_classification and full_dataset.y_size==1:\n", + " log_grouper = CombinatorialGrouper(\n", + " dataset=full_dataset,\n", + " groupby_fields=['y'])\n", + "elif config.no_group_logging:\n", + " log_grouper = None\n", + "else:\n", + " log_grouper = train_grouper\n", + "log_group_data(datasets, log_grouper, logger)\n", + "\n", + "## Initialize algorithm\n", + "algorithm = initialize_algorithm(\n", + " config=config,\n", + " datasets=datasets,\n", + " train_grouper=train_grouper)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", + "val_celltype = ['A549']\n", + "test_celltype = ['GM12878']\n", + "all_celltypes = train_celltypes + val_celltype + test_celltype\n", + "\n", + "metadata_map = {}\n", + "metadata_map['chr'] = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']\n", + "metadata_map['celltype'] = all_celltypes\n", + "\n", + "_split_dict = {\n", + " 'train': 0,\n", + " 'val-id': 1,\n", + " 'test': 2,\n", + " 'val-ood': 3\n", + "}\n", + "_split_names = {\n", + " 'train': 'Train',\n", + " 'val-id': 'Validation (ID)',\n", + " 'test': 'Test',\n", + " 'val-ood': 'Validation (OOD)'\n", + "}\n", + "_split_scheme = 'standard'" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "('H1-hESC', 25.299736976623535)\n", + "('HCT116', 49.68733310699463)\n", + "('HeLa-S3', 74.65905213356018)\n", + "('HepG2', 99.33112812042236)\n", + "('K562', 124.1327919960022)\n", + "('A549', 149.19999814033508)\n", + "('GM12878', 174.0277030467987)\n" + ] + } + ], + "source": [ + "itime = time.time()\n", + "sequence_filename = os.path.join(data_dir, 'sequence.npz')\n", + "seq_arr = np.load(sequence_filename)\n", + "print(time.time() - itime)\n", + "\n", + "itime = time.time()\n", + "_seq_bp = {}\n", + "for chrom in seq_arr:\n", + " _seq_bp[chrom] = seq_arr[chrom]\n", + " print(chrom, time.time() - itime)\n", + "itime = time.time()\n", + "_dnase_allcelltypes = {}\n", + "for ct in all_celltypes:\n", + " dnase_filename = os.path.join(data_dir, '{}_dnase.npz'.format(ct))\n", + " dnase_npz_file = np.load(dnase_filename)\n", + " _dnase_allcelltypes[ct] = {}\n", + " for chrom in _seq_bp:\n", + " _dnase_allcelltypes[ct][chrom] = dnase_npz_file[chrom]\n", + " print(ct, time.time() - itime)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Train/eval" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if not config.eval_only:\n", + " ## Load saved results if resuming\n", + " resume_success = False\n", + " if resume:\n", + " save_path = os.path.join(config.log_dir, 'last_model.pth')\n", + " if not os.path.exists(save_path):\n", + " epochs = [\n", + " int(file.split('_')[0])\n", + " for file in os.listdir(config.log_dir) if file.endswith('.pth')]\n", + " if len(epochs) > 0:\n", + " latest_epoch = max(epochs)\n", + " save_path = os.path.join(config.log_dir, f'{latest_epoch}_model.pth')\n", + " try:\n", + " prev_epoch, best_val_metric = load(algorithm, save_path)\n", + " epoch_offset = prev_epoch + 1\n", + " logger.write(f'Resuming from epoch {epoch_offset} with best val metric {best_val_metric}')\n", + " resume_success = True\n", + " except FileNotFoundError:\n", + " pass\n", + "\n", + " if resume_success == False:\n", + " epoch_offset=0\n", + " best_val_metric=None\n", + "\n", + "\n", + " train(\n", + " algorithm=algorithm,\n", + " datasets=datasets,\n", + " general_logger=logger,\n", + " config=config,\n", + " epoch_offset=epoch_offset,\n", + " best_val_metric=best_val_metric)\n", + "else:\n", + " if config.eval_epoch is None:\n", + " eval_model_path = os.path.join(config.log_dir, 'best_model.pth')\n", + " else:\n", + " eval_model_path = os.path.join(config.log_dir, f'{config.eval_epoch}_model.pth')\n", + " best_epoch, best_val_metric = load(algorithm, eval_model_path)\n", + " if config.eval_epoch is None:\n", + " epoch = best_epoch\n", + " else:\n", + " epoch = config.eval_epoch\n", + " evaluate(\n", + " algorithm=algorithm,\n", + " datasets=datasets,\n", + " epoch=epoch,\n", + " general_logger=logger,\n", + " config=config)\n", + "\n", + "logger.close()\n", + "for split in datasets:\n", + " datasets[split]['eval_logger'].close()\n", + " datasets[split]['algo_logger'].close()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 78, + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "class Beagle(nn.Module):\n", + " \"\"\"\n", + " Neural net models over genomic sequence.\n", + " Input:\n", + " - sequence_length: int (default 1000) \n", + " - Shape: (N, 5, sequence_length, 1) with batch size N.\n", + " \n", + " Output:\n", + " - prediction (Tensor): float torch tensor of shape (N, )\n", + " \n", + " TODO: Finish docstring.\n", + " \"\"\"\n", + " def __init__(self):\n", + " \"\"\"\n", + " Parameters\n", + " ----------\n", + " sequence_length : int\n", + " n_genomic_features : int\n", + " \"\"\"\n", + " super(Beagle, self).__init__()\n", + "\n", + " self.dropout = 0.3\n", + " self.num_cell_types = 1\n", + " self.conv1 = nn.Conv2d(5, 300, (19, 1), stride = (1, 1), padding=(9,0))\n", + " self.conv2 = nn.Conv2d(300, 200, (11, 1), stride = (1, 1), padding = (5,0))\n", + " self.conv3 = nn.Conv2d(200, 200, (7, 1), stride = (1, 1), padding = (4,0))\n", + " self.bn1 = nn.BatchNorm2d(300)\n", + " self.bn2 = nn.BatchNorm2d(200)\n", + " self.bn3 = nn.BatchNorm2d(200)\n", + " self.maxpool1 = nn.MaxPool2d((3, 1))\n", + " self.maxpool2 = nn.MaxPool2d((4, 1))\n", + " self.maxpool3 = nn.MaxPool2d((4, 1))\n", + "\n", + " self.fc1 = nn.Linear(4200, 1000)\n", + " self.bn4 = nn.BatchNorm1d(1000)\n", + "\n", + " self.fc2 = nn.Linear(1000, 1000)\n", + " self.bn5 = nn.BatchNorm1d(1000)\n", + "\n", + " self.fc3 = nn.Linear(1000, self.num_cell_types)\n", + "\n", + " def forward(self, s):\n", + " s = s.permute(0, 2, 1).contiguous() # batch_size x 5 x 1000\n", + " s = s.view(-1, 5, 1000, 1) # batch_size x 5 x 1000 x 1 [5 channels]\n", + " s = self.maxpool1(F.relu(self.bn1(self.conv1(s)))) # batch_size x 300 x 333 x 1\n", + " s = self.maxpool2(F.relu(self.bn2(self.conv2(s)))) # batch_size x 200 x 83 x 1\n", + " s = self.maxpool3(F.relu(self.bn3(self.conv3(s)))) # batch_size x 200 x 21 x 1\n", + " s = s.view(-1, 4200)\n", + " conv_out = s\n", + "\n", + " s = F.dropout(F.relu(self.bn4(self.fc1(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", + " s = F.dropout(F.relu(self.bn5(self.fc2(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", + " \n", + " s = self.fc3(s)\n", + "\n", + " return s, conv_out" + ] + }, + { + "cell_type": "code", + "execution_count": 86, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[('nnet.0.weight', 33280),\n", + " ('nnet.0.bias', 320),\n", + " ('bdlstm.0.weight_ih_l0', 409600),\n", + " ('bdlstm.0.weight_hh_l0', 409600),\n", + " ('bdlstm.0.bias_ih_l0', 1280),\n", + " ('bdlstm.0.bias_hh_l0', 1280),\n", + " ('bdlstm.0.weight_ih_l0_reverse', 409600),\n", + " ('bdlstm.0.weight_hh_l0_reverse', 409600),\n", + " ('bdlstm.0.bias_ih_l0_reverse', 1280),\n", + " ('bdlstm.0.bias_hh_l0_reverse', 1280),\n", + " ('classifier.1.weight', 592000),\n", + " ('classifier.1.bias', 925),\n", + " ('classifier.3.weight', 4625),\n", + " ('classifier.3.bias', 5)]" + ] + }, + "execution_count": 86, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def count_parameters(model):\n", + " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", + "\n", + "model = Beagle2()\n", + "model = DanQ(50, 5)\n", + "\n", + "lst = [(x[0], x[1].numel()) for x in model.named_parameters()]\n", + "#np.sum([x[1] for x in lst])\n", + "count_parameters(model)\n", + "lst" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 327ffb3e..f8b66f25 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -33,25 +33,43 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): self._y_size = 1 self._n_classes = 2 - # self._tr_chrs = ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] - self._tr_chrs = ['chr2', 'chr9', 'chr11'] - self._te_chrs = ['chr1', 'chr8', 'chr21'] + # self._train_chroms = ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] + self._train_chroms = ['chr2', 'chr9', 'chr11'] + self._test_chroms = ['chr1', 'chr8', 'chr21'] self._transcription_factor = 'MAX' self._train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] self._val_celltype = ['A549'] self._test_celltype = ['GM12878'] + self._all_chroms = self._train_chroms + self._test_chroms self._all_celltypes = self._train_celltypes + self._val_celltype + self._test_celltype - self._metadata_fields = ['chr', 'celltype', 'y'] self._metadata_map = {} - self._metadata_map['chr'] = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] + self._metadata_map['chr'] = self._all_chroms #['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] self._metadata_map['celltype'] = self._all_celltypes + # Get the splits + if split_scheme=='official': + split_scheme = 'standard' + + self._split_scheme = split_scheme + self._split_dict = { + 'train': 0, + 'id_val': 1, + 'test': 2, + 'val': 3 + } + self._split_names = { + 'train': 'Train', + 'id_val': 'Validation (ID)', + 'test': 'Test', + 'val': 'Validation (OOD)', + } + # Load sequence and DNase features sequence_filename = os.path.join(self._data_dir, 'sequence.npz') seq_arr = np.load(sequence_filename) self._seq_bp = {} - for chrom in seq_arr: + for chrom in self._all_chroms: #seq_arr: self._seq_bp[chrom] = seq_arr[chrom] print(chrom, time.time() - itime) @@ -60,29 +78,27 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): dnase_filename = os.path.join(self._data_dir, '{}_dnase.npz'.format(ct)) dnase_npz_contents = np.load(dnase_filename) self._dnase_allcelltypes[ct] = {} - for chrom in self._seq_bp: + for chrom in self._all_chroms: #self._seq_bp: self._dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom] print(ct, time.time() - itime) # Read in metadata dataframe from training+validation data train_regions_labeled = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.train.labels.tsv.gz'.format(self._transcription_factor)), sep='\t') val_regions_labeled = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.val.labels.tsv.gz'.format(self._transcription_factor)), sep='\t') - training_df = train_regions_labeled[np.isin(train_regions_labeled['chr'], self._tr_chrs)] - val_df = val_regions_labeled[np.isin(val_regions_labeled['chr'], self._te_chrs)] + training_df = train_regions_labeled[np.isin(train_regions_labeled['chr'], self._train_chroms)] + val_df = val_regions_labeled[np.isin(val_regions_labeled['chr'], self._test_chroms)] all_df = pd.concat([training_df, val_df]) - print(train_regions_labeled, time.time() - itime) - # Filter by start/stop coordinate if needed - # (TODO: remove for final version) + # Filter by start/stop coordinate if needed (TODO: remove for final version) filter_msk = all_df['start'] >= 0 filter_msk = all_df['start']%1000 == 0 all_df = all_df[filter_msk] pd_list = [] - for ct in self._train_celltypes: + for ct in self._all_celltypes: tc_chr = all_df[['chr', 'start', 'stop', ct]] tc_chr.columns = ['chr', 'start', 'stop', 'y'] - tc_chr['celltype'] = ct + tc_chr.insert(len(tc_chr.columns), 'celltype', ct) pd_list.append(tc_chr) metadata_df = pd.concat(pd_list) @@ -91,33 +107,9 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): non_ambig_mask = (y_array != -1) metadata_df['y'] = y_array self._metadata_df = metadata_df[non_ambig_mask] - self._y_array = torch.LongTensor(y_array[non_ambig_mask]) - chr_ints = self._metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['chr'])] )).values - celltype_ints = self._metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['celltype'])] )).values - self._metadata_array = torch.stack( - (torch.LongTensor(chr_ints), - torch.LongTensor(celltype_ints), - self._y_array), - dim=1) - - # Get the splits - # TODO Extract splits as encoded in split_scheme. Hardcoded here for now. - self._split_scheme = split_scheme - self._split_dict = { - 'train': 0, - 'val-id': 1, - 'test': 2, - 'val-ood': 3 - } - self._split_names = { - 'train': 'Train', - 'val-id': 'Validation (ID)', - 'test': 'Test', - 'val-ood': 'Validation (OOD)', - } - train_regions_mask = np.isin(self._metadata_df['chr'], self._tr_chrs) - val_regions_mask = np.isin(self._metadata_df['chr'], self._te_chrs) + train_regions_mask = np.isin(self._metadata_df['chr'], self._train_chroms) + val_regions_mask = np.isin(self._metadata_df['chr'], self._test_chroms) train_celltype_mask = np.isin(self._metadata_df['celltype'], self._train_celltypes) val_celltype_mask = np.isin(self._metadata_df['celltype'], self._val_celltype) test_celltype_mask = np.isin(self._metadata_df['celltype'], self._test_celltype) @@ -125,17 +117,33 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): split_array = -1*np.ones(self._metadata_df.shape[0]).astype(int) split_array[np.logical_and(train_regions_mask, train_celltype_mask)] = self._split_dict['train'] split_array[np.logical_and(val_regions_mask, test_celltype_mask)] = self._split_dict['test'] - # Validate using test chr, either using a designated validation cell line ('val-ood') or a training cell line ('val-id') - split_array[np.logical_and(val_regions_mask, val_celltype_mask)] = self._split_dict['val-ood'] - split_array[np.logical_and(val_regions_mask, train_celltype_mask)] = self._split_dict['val-id'] + # Validate using test chr, either using a designated validation cell line ('val') or a training cell line ('id_val') + split_array[np.logical_and(val_regions_mask, val_celltype_mask)] = self._split_dict['val'] + split_array[np.logical_and(val_regions_mask, train_celltype_mask)] = self._split_dict['id_val'] + if self._split_scheme=='standard': - self._metadata_df['split'] = split_array - self._split_array = split_array + self._metadata_df.insert(len(self._metadata_df.columns), 'split', split_array) else: raise ValueError(f'Split scheme {self._split_scheme} not recognized') + + self._metadata_df = self._metadata_df[self._metadata_df['split'] != -1] + self._split_array = self._metadata_df['split'].values + + chr_ints = self._metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['chr'])] )).values + celltype_ints = self._metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['celltype'])] )).values + self._y_array = torch.LongTensor(np.array(self._metadata_df['y'])) + + self._metadata_array = torch.stack( + (torch.LongTensor(chr_ints), + torch.LongTensor(celltype_ints), + self._y_array), + dim=1) + self._metadata_fields = ['chr', 'celltype', 'y'] + self._eval_grouper = CombinatorialGrouper( dataset=self, groupby_fields=['celltype']) + self._metric = Accuracy() super().__init__(root_dir, download, split_scheme) @@ -152,8 +160,8 @@ def get_input(self, idx): flank_size = 400 interval_start = this_metadata['start'] - flank_size interval_end = this_metadata['stop'] + flank_size - dnase_this = _dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end] - seq_this = _seq_bp[this_metadata['chr']][interval_start:interval_end] + dnase_this = self._dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end] + seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end] return torch.tensor(np.column_stack([seq_this, dnase_this])) def eval(self, y_pred, y_true, metadata): From 33b6756e5ef1b8ed011688b8bab3026b462a25c4 Mon Sep 17 00:00:00 2001 From: aikanor Date: Thu, 11 Feb 2021 09:15:50 -0800 Subject: [PATCH 069/244] integration 16/ (with most remaining bugfixes) --- examples/configs/datasets.py | 4 +- examples/models/CNN_genome.py | 1 + examples/sbox_run_expt.ipynb | 1465 +++++--------------------- wilds/datasets/encodetfbs_dataset.py | 14 + 4 files changed, 269 insertions(+), 1215 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 727d0e92..4d941653 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -112,13 +112,13 @@ 'train_transform': None, 'eval_transform': None, 'loss_function': 'cross_entropy', - 'groupby_fields': ['celltype'], + 'groupby_fields': ['celltype', 'y'], 'val_metric': 'acc_avg', 'val_metric_decreasing': False, 'optimizer': 'Adam', # 'optimizer_kwargs': { }, 'scheduler': None, - 'batch_size': 128, + 'batch_size': 64, 'lr': 0.001, 'weight_decay': 0.01, 'n_epochs': 1, diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py index cc464ef0..1c65b567 100644 --- a/examples/models/CNN_genome.py +++ b/examples/models/CNN_genome.py @@ -7,6 +7,7 @@ class Beagle(nn.Module): """ Neural net models over genomic sequence. Adapted from https://github.com/kundajelab/ChromDragoNN + Input: - s (Tensor): float torch tensor of shape (N, 5, 1000, 1) with batch size N. diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb index 56a9f6a2..e50f790b 100644 --- a/examples/sbox_run_expt.ipynb +++ b/examples/sbox_run_expt.ipynb @@ -11,7 +11,28 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'psutil'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpsutil\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpsutil\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mProcess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgetpid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmemory_info\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrss\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0;36m1024\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'psutil'" + ] + } + ], + "source": [ + "import os, psutil; print(psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -38,16 +59,16 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" + "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" ] }, - "execution_count": 2, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -142,7 +163,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -154,19 +175,13 @@ "config_encode = parser.parse_args(argstr_encode.split())\n", "config_encode = populate_defaults(config_encode)\n", "\n", - "config = config_camelyon\n", - "#config = config_encode" + "config = config_camelyon" ] }, { "cell_type": "code", - "execution_count": 6, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true - } - }, + "execution_count": 5, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -275,6 +290,19 @@ " dataset=full_dataset)" ] }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "import copy\n", + "full_dataset_camelyon17 = copy.deepcopy(full_dataset)\n", + "\n", + "# supported.datasets[config_encode.dataset]\n", + "# print(config_camelyon.train_transform, config_encode.train_transform)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -284,11 +312,10 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 30, "metadata": { - "collapsed": true, "jupyter": { - "outputs_hidden": true + "source_hidden": true } }, "outputs": [ @@ -296,46 +323,19 @@ "name": "stdout", "output_type": "stream", "text": [ - "chr2 3.817250967025757\n", - "chr9 6.033524990081787\n", - "chr11 8.150986433029175\n", - "chr1 12.036555290222168\n", - "chr8 14.306443929672241\n", - "chr21 15.043241739273071\n", - "H1-hESC 21.61008930206299\n", - "HCT116 28.000329971313477\n", - "HeLa-S3 34.6184778213501\n", - "HepG2 41.089255809783936\n", - "K562 47.70136523246765\n", - "A549 54.22390341758728\n", - "GM12878 60.65142226219177\n", - " chr start stop A549 GM12878 H1-hESC HCT116 HeLa-S3 \\\n", - "0 chr10 600 800 U U U U U \n", - "1 chr10 650 850 U U U U U \n", - "2 chr10 700 900 U U U U U \n", - "3 chr10 750 950 U U U U U \n", - "4 chr10 800 1000 U U U U U \n", - "... ... ... ... ... ... ... ... ... \n", - "51676731 chrX 155269750 155269950 U U U U U \n", - "51676732 chrX 155269800 155270000 U U U U U \n", - "51676733 chrX 155269850 155270050 U U U U U \n", - "51676734 chrX 155269900 155270100 U U U U U \n", - "51676735 chrX 155269950 155270150 U U U U U \n", - "\n", - " HepG2 K562 \n", - "0 U U \n", - "1 U U \n", - "2 U U \n", - "3 U U \n", - "4 U U \n", - "... ... ... \n", - "51676731 U U \n", - "51676732 U U \n", - "51676733 U U \n", - "51676734 U U \n", - "51676735 U U \n", - "\n", - "[51676736 rows x 10 columns] 130.07371044158936\n" + "chr2 3.657395362854004\n", + "chr9 5.770605564117432\n", + "chr11 7.801896095275879\n", + "chr1 11.56663990020752\n", + "chr8 13.764073133468628\n", + "chr21 14.483267068862915\n", + "H1-hESC 20.850953817367554\n", + "HCT116 27.05355429649353\n", + "HeLa-S3 33.51919412612915\n", + "HepG2 39.89570116996765\n", + "K562 46.36982774734497\n", + "A549 52.82617139816284\n", + "GM12878 59.167165994644165\n" ] } ], @@ -417,9 +417,9 @@ "all_df = pd.concat([training_df, val_df])\n", "\n", "# Filter by start/stop coordinate if needed (TODO: remove for final version)\n", - "filter_msk = all_df['start'] >= 0\n", - "filter_msk = all_df['start']%1000 == 0\n", - "all_df = all_df[filter_msk]\n", + "# filter_msk = all_df['start'] >= 0\n", + "# filter_msk = all_df['start']%1000 == 0\n", + "# all_df = all_df[filter_msk]\n", "\n", "pd_list = []\n", "for ct in _all_celltypes:\n", @@ -427,20 +427,39 @@ " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", " tc_chr.insert(len(tc_chr.columns), 'celltype', ct)\n", " pd_list.append(tc_chr)\n", - "metadata_df = pd.concat(pd_list)" + "metadata_df = pd.concat(pd_list)\n", + "\n", + "# Get the y values, and remove ambiguous labels by default.\n", + "y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", + "non_ambig_mask = (y_array != -1)\n", + "metadata_df['y'] = y_array\n", + "_metadata_df = metadata_df[non_ambig_mask]" ] }, { "cell_type": "code", - "execution_count": 131, - "metadata": {}, + "execution_count": 35, + "metadata": { + "jupyter": { + "source_hidden": true + } + }, "outputs": [], "source": [ - "# Get the y values, and remove ambiguous labels by default.\n", - "y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", - "non_ambig_mask = (y_array != -1)\n", - "metadata_df['y'] = y_array\n", - "_metadata_df = metadata_df[non_ambig_mask]\n", + "samp_ndces = []\n", + "itime = time.time()\n", + "for ct in _all_celltypes:\n", + " neg_msk = np.logical_and((_metadata_df['celltype'] == ct), (_metadata_df['y'] == 0))\n", + " pos_msk = np.logical_and((_metadata_df['celltype'] == ct), (_metadata_df['y'] == 1))\n", + " neg_ndces = np.where(neg_msk)[0]\n", + " pos_ndces = np.where(pos_msk)[0]\n", + " np.random.seed(42)\n", + " samp_neg_ndces = np.random.choice(neg_ndces, size=len(pos_ndces), replace=False)\n", + " samp_ndces.extend(samp_neg_ndces)\n", + " samp_ndces.extend(pos_ndces)\n", + " print(ct, time.time() - itime)\n", + "\n", + "_metadata_df = _metadata_df.iloc[samp_ndces, :]\n", "\n", "train_regions_mask = np.isin(_metadata_df['chr'], _train_chroms)\n", "val_regions_mask = np.isin(_metadata_df['chr'], _test_chroms)\n", @@ -472,7 +491,7 @@ " torch.LongTensor(celltype_ints), \n", " _y_array),\n", " dim=1)\n", - "_metadata_fields = ['chr', 'celltype', 'y']\n" + "_metadata_fields = ['chr', 'celltype', 'y']" ] }, { @@ -484,7 +503,7 @@ }, { "cell_type": "code", - "execution_count": 138, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ @@ -523,8 +542,8 @@ " self._y_size = 1\n", " self._n_classes = 2\n", " \n", - " # self._train_chroms = ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX']\n", - " self._train_chroms = ['chr2', 'chr9', 'chr11']\n", + " self._train_chroms = ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX']\n", + " # self._train_chroms = ['chr2', 'chr9', 'chr11']\n", " self._test_chroms = ['chr1', 'chr8', 'chr21']\n", " self._transcription_factor = 'MAX'\n", " self._train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", @@ -580,9 +599,9 @@ " all_df = pd.concat([training_df, val_df])\n", " \n", " # Filter by start/stop coordinate if needed (TODO: remove for final version)\n", - " filter_msk = all_df['start'] >= 0\n", - " filter_msk = all_df['start']%1000 == 0\n", - " all_df = all_df[filter_msk]\n", + " # filter_msk = all_df['start'] >= 0\n", + " # filter_msk = all_df['start']%1000 == 0\n", + " # all_df = all_df[filter_msk]\n", " \n", " pd_list = []\n", " for ct in self._all_celltypes:\n", @@ -598,6 +617,20 @@ " metadata_df['y'] = y_array\n", " self._metadata_df = metadata_df[non_ambig_mask]\n", " \n", + " samp_ndces = []\n", + " itime = time.time()\n", + " for ct in self._all_celltypes:\n", + " neg_msk = np.logical_and((self._metadata_df['celltype'] == ct), (self._metadata_df['y'] == 0))\n", + " pos_msk = np.logical_and((self._metadata_df['celltype'] == ct), (self._metadata_df['y'] == 1))\n", + " neg_ndces = np.where(neg_msk)[0]\n", + " pos_ndces = np.where(pos_msk)[0]\n", + " np.random.seed(42)\n", + " samp_neg_ndces = np.random.choice(neg_ndces, size=len(pos_ndces), replace=False)\n", + " samp_ndces.extend(samp_neg_ndces)\n", + " samp_ndces.extend(pos_ndces)\n", + " print(ct, time.time() - itime)\n", + " self._metadata_df = self._metadata_df.iloc[samp_ndces, :]\n", + " \n", " train_regions_mask = np.isin(self._metadata_df['chr'], self._train_chroms)\n", " val_regions_mask = np.isin(self._metadata_df['chr'], self._test_chroms)\n", " train_celltype_mask = np.isin(self._metadata_df['celltype'], self._train_celltypes)\n", @@ -663,26 +696,43 @@ }, { "cell_type": "code", - "execution_count": 139, + "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "chr2 3.836134910583496\n", - "chr9 6.034452438354492\n", - "chr11 8.16244888305664\n", - "chr1 12.12421178817749\n", - "chr8 14.44963550567627\n", - "chr21 15.212148189544678\n", - "H1-hESC 21.892271518707275\n", - "HCT116 28.37229895591736\n", - "HeLa-S3 35.18828296661377\n", - "HepG2 41.83891773223877\n", - "K562 48.590251445770264\n", - "A549 55.3311812877655\n", - "GM12878 61.93817687034607\n" + "chr2 3.718320846557617\n", + "chr3 6.73882269859314\n", + "chr4 9.651247501373291\n", + "chr5 12.439628839492798\n", + "chr6 15.05026388168335\n", + "chr7 17.475954055786133\n", + "chr9 19.6206693649292\n", + "chr10 21.68758535385132\n", + "chr11 23.74817419052124\n", + "chr12 25.81403160095215\n", + "chr13 27.559557676315308\n", + "chr14 29.18643832206726\n", + "chr15 30.739391565322876\n", + "chr16 32.11144256591797\n", + "chr17 33.348127126693726\n", + "chr18 34.53834342956543\n", + "chr19 35.434733629226685\n", + "chr20 36.399296283721924\n", + "chr22 37.1924102306366\n", + "chrX 39.56284308433533\n", + "chr1 43.3526566028595\n", + "chr8 45.583492040634155\n", + "chr21 46.311339378356934\n", + "H1-hESC 66.45292735099792\n", + "HCT116 86.06067085266113\n", + "HeLa-S3 106.47142815589905\n", + "HepG2 126.59437656402588\n", + "K562 146.93650436401367\n", + "A549 167.19306707382202\n", + "GM12878 187.4349775314331\n" ] } ], @@ -696,1068 +746,63 @@ }, { "cell_type": "code", - "execution_count": 140, + "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "<__main__.EncodeTFBSDataset at 0x7fe6b69d33a0>" + "(array(['A549', 'GM12878', 'H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'],\n", + " dtype=object),\n", + " array([ 5118, 1702, 8460, 12806, 8348, 11774, 12518]))" ] }, - "execution_count": 140, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "full_dataset_encode" + "np.unique(full_dataset_encode._metadata_df['celltype'], return_counts=True)" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 17, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "(array([0, 1]), array([227977, 227977]))" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([0, 0, 0, ..., 1, 1, 1]) torch.Size([60726])\n", + "(array([0, 1]), array([30426, 30300]))\n", + "(array([0, 1, 2, 3]), array([28556, 25350, 1702, 5118]))\n" + ] } ], - "source": [] + "source": [ + "full_dataset = copy.deepcopy(full_dataset_encode)\n", + "print(full_dataset._y_array, full_dataset._y_array.shape)\n", + "print(np.unique(full_dataset.y_array.numpy(), return_counts=True))\n", + "print(np.unique(full_dataset._metadata_df['split'], return_counts=True))\n", + "\n", + "#full_dataset._input_array" + ] }, { "cell_type": "code", - "execution_count": 17, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true - } - }, + "execution_count": 9, + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "tensor([1, 1, 1, ..., 0, 0, 0]) torch.Size([455954])\n" + "tensor([0, 0, 0, ..., 0, 0, 0]) torch.Size([5568233])\n", + "(array([0, 1]), array([5537933, 30300]))\n", + "(array([0, 1, 2, 3]), array([2533595, 2163528, 437124, 433986]))\n" ] - }, - { - "data": { - "text/plain": [ - "['patches/patient_004_node_4/patch_patient_004_node_4_x_3328_y_21792.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3200_y_22272.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3168_y_22272.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3328_y_21760.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3232_y_22240.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3168_y_22240.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3136_y_22208.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2656_y_18880.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3136_y_22240.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3296_y_21856.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3296_y_21792.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3360_y_21824.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3360_y_21760.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3296_y_21824.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3328_y_21824.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2688_y_18912.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3168_y_22176.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2688_y_18816.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3200_y_22176.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3168_y_22208.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2688_y_18880.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3296_y_21760.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2656_y_18848.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3136_y_22272.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3264_y_21856.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3264_y_21824.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2688_y_18848.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3264_y_21792.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2656_y_18944.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3200_y_22208.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3232_y_22208.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3200_y_22240.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2688_y_18944.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_3360_y_21792.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2656_y_18912.png',\n", - " 'patches/patient_004_node_4/patch_patient_004_node_4_x_2656_y_18816.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12960_y_36128.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_35968.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36512.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36064.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13312_y_36320.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35968.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34560.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36384.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36192.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13120_y_35936.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36480.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_35680.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35648.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36032.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36416.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36160.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12160_y_34752.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36288.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36160.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_35744.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_35840.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36320.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12896_y_35648.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35904.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36512.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13120_y_35904.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35872.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36192.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36224.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36416.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_35872.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13152_y_35872.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_35904.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36224.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36192.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_35968.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_35680.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_35648.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36096.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_35616.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36224.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35776.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36064.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35936.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_35808.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36224.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36320.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36032.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36128.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34752.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36288.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35904.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36160.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_36192.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13184_y_35968.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_35808.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13184_y_35936.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36192.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34560.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34528.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13120_y_36064.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36032.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34784.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36288.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36224.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36256.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35584.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_35616.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34720.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36352.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36160.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34624.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36416.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36384.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_36352.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_35712.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12960_y_36096.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12960_y_36352.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36384.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36352.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36448.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36448.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_36288.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36128.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12064_y_34560.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_35808.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34592.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12960_y_36448.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_35616.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36224.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36000.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_35840.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36096.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36416.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36352.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12064_y_34720.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_36064.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_35648.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36352.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_36096.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13120_y_35872.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36192.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36256.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_35840.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36128.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_35872.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_36064.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36032.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36320.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_36224.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_35968.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_35808.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36064.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12896_y_36224.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_35744.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36128.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36352.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36096.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13312_y_36224.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35872.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34624.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_35872.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36384.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_35936.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_35936.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36192.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34720.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_36096.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36352.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36128.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_35488.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36512.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13376_y_36192.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36096.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36160.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13376_y_36128.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36064.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36256.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_36128.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36480.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36480.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36480.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36544.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_36032.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36384.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36096.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36320.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36224.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12896_y_36256.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13344_y_36288.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34656.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36544.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36064.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_35936.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35776.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36064.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13120_y_36000.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35744.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12960_y_36416.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36224.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36192.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_35840.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_35872.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36160.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12160_y_34720.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36192.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_35872.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36096.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36480.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34592.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35936.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36096.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13376_y_36160.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_36288.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_35872.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12064_y_34688.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36160.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12896_y_35616.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36096.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35616.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_35840.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36000.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36032.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_35840.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35808.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_35904.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36032.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34752.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36032.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13120_y_35968.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12064_y_34624.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_35680.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_35808.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12160_y_34688.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_35968.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36000.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36416.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36288.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36352.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36096.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12960_y_36064.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12128_y_34688.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13152_y_35936.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13152_y_36000.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36224.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13344_y_36192.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36096.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35744.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12064_y_34592.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34688.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_35840.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36320.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_35584.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34656.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36416.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_35680.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36000.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_35936.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12096_y_34784.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_35936.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35648.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_35712.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36416.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36256.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36416.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36192.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_36320.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12896_y_36288.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36064.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_36128.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36128.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36096.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36128.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36384.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12160_y_34656.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_36448.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36352.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_36000.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36320.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_35840.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12064_y_34528.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_35936.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36416.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36320.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12992_y_35872.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36256.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36544.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13120_y_36032.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36224.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_35904.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36064.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36160.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36256.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_35488.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36448.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_36288.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_35936.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36096.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36256.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36128.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13280_y_36288.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_35968.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_35680.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_36160.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12864_y_35648.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12960_y_36384.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36288.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36128.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12768_y_36512.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_36160.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13344_y_36224.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_36128.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36128.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36480.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36224.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_36256.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13184_y_35904.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_36256.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12576_y_35776.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12544_y_36160.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_35840.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12640_y_35904.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35968.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_35968.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13344_y_36160.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_35968.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12608_y_35872.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13088_y_36032.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13312_y_36288.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12480_y_35776.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12512_y_35808.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13184_y_35872.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_35904.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13152_y_35968.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_35904.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36256.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12736_y_36288.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36448.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13024_y_36000.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12832_y_36448.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_36000.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12448_y_36032.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12672_y_35872.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12800_y_36320.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13056_y_36384.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_13152_y_35904.png',\n", - " 'patches/patient_009_node_1/patch_patient_009_node_1_x_12704_y_36256.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16416.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16544.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16384_y_24352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_16544.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16288_y_24448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17216.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_15968.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_17056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18400_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16032.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16544_y_24768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_16288.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17024_y_24512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_16224.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15488.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_17440.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17536_y_15552.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16384.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_15904.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16320.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17472.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16864.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16192.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17632_y_15104.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_17024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16704.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_15648.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16416.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16832.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17216.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_15616.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16288.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_15680.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16896.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16320_y_25056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17344_y_28160.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16704.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16928.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15520.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16480_y_24992.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17536.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_17216.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_15872.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17088_y_24448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_16192.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16160.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17088_y_24512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19360_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_17024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16160.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_16352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_17088.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16160.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16256.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16896.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16544.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16416.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16320.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_17152.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_17248.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_17024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16320_y_24448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_16800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16768_y_25152.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16736_y_25152.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16704_y_24640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16064.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17600_y_15488.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_17376.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16864_y_24288.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17120_y_25088.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16544_y_24800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18304_y_16064.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17056_y_24512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16096.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16384.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17280_y_24768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16288.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_17056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16768_y_24928.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16320.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_17024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17184.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16288_y_25024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_15648.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16320.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16352_y_25024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17184.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_17056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16544.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16288.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_15712.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_17152.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_16480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17216.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16544.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_15744.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16096.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16256_y_24448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16480_y_24832.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_25024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16832.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_16448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_15680.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16480_y_24800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16992_y_24448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_16768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16224.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_17280.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16832_y_24288.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_17408.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16192.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16960.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16000.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17088.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17248_y_24768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16832.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17504.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17632_y_15520.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16832.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16992.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17248.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17504.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16256_y_24480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16000.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16320_y_24608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16416.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17280_y_28064.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16032.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17536_y_25120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17152.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16192.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17408_y_24896.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19296_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16096.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16992.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16704.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16864.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_17056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_17120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16352_y_24512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_15648.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16896.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_16704.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_17088.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16256.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19360_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16768_y_25184.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16288.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17184.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16160_y_24864.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16032.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16416.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16640_y_25440.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17280_y_28192.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16800_y_25120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_17344.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16224.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16384.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16832.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16416.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18400_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17600_y_15392.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16320_y_24480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_15968.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16640_y_24576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16992_y_24960.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16672_y_24512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18336_y_16064.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16096.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16128_y_24896.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17440.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_17248.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16544_y_25024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16064.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_17440.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17536.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17440_y_24928.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_17216.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16864_y_24352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16224.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_17184.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16064.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16000.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_17344.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_15680.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16672_y_24928.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17536_y_15520.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16704_y_24608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16544.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_16416.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17312_y_28128.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17312.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_17120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16448_y_24768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16128.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16640_y_24640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16480_y_25024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17152.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_17120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_17280.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_15584.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_15648.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17408_y_24928.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16832.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17344_y_28128.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_17184.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_17440.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16864.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18400_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_16960.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_15968.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_15680.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_17024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_15616.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_16192.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18336_y_16160.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16864.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17024_y_24960.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16320_y_24640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15328.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17408.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_15712.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16736_y_24608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16704.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16928.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17088.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16704.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16960.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16960.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16096.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16448_y_25024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16384.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16832.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17376_y_24896.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_15648.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16064.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_15584.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16192.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16544.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16864.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16768_y_25120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_17248.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16384_y_24640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_17184.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_17088.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_17376.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_17088.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16544.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16064.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16512_y_24352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17568_y_15488.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_16128.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16992.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16704.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15456.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16480_y_25056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_17088.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16256.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19296_y_16832.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16928.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17344.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_17120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16352_y_24608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15104.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16320.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16800_y_24928.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16672_y_24960.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17696_y_15488.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_15744.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_17344.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16704_y_24960.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17184_y_25120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16064.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16992_y_24480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_15776.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_15744.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16096.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19328_y_16800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17856_y_15296.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_15712.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_15904.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16384_y_24512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16992.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16096.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16512_y_24800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_17216.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_17056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_25120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17280.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16416_y_24320.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_17376.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16576_y_24768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_17440.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17312_y_28032.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16992.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17568_y_25216.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16320.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_17120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16256.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_17024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17056_y_24480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16832_y_24864.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16704.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16288_y_24416.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16704.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16192.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16256.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16960.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17472.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16320.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_15776.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18784_y_17440.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16192.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_17216.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16544.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16064.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_17376.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17152_y_25088.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16032.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16544.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16224.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17312.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_17056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16384.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_16416.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16416.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16576_y_24608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17248.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16896.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17248_y_24800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16928.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16896.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16416.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16000.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16640_y_24384.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16416.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17632_y_15072.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_17408.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17600_y_15136.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_15936.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17600_y_15072.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16000.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16000.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16864_y_24320.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16800_y_25152.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16896_y_24576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_17440.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16096.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16384.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16384.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15200.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17120_y_25152.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17568_y_15392.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16416_y_24352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16512_y_24320.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16096.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16480_y_24352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16960.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16960.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_17248.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16576_y_25024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16032.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16704.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_15648.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16448_y_24384.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_25088.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17568_y_25152.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_17216.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18336_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16928.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15424.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16992.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16288.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15072.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16992.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_17280.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18400_y_16352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18816_y_16512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17696_y_15232.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_17152.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16288.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17664_y_15136.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_15712.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_15712.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16704.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_15680.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19296_y_16544.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17248.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16544_y_25056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_16832.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16288_y_24480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16384.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_17248.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16128.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_17440.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17696_y_15520.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16512_y_25024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16064.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_16128.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_17184.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16416.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17280_y_24800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18432_y_16256.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_16096.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19328_y_16832.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16384_y_24320.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17600_y_25216.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_17248.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17280.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18688_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16960.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16640_y_25408.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16736_y_24640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16288_y_25056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17152.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16224.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_17152.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18720_y_16480.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_17024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16928.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16736_y_24576.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16640_y_25056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16512_y_24768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_15680.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16896_y_24320.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16864.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16896.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_16448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17056_y_24448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16704.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16160.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16672_y_25088.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16864.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17600_y_15360.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16032.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18368_y_16704.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_15936.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_17088.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17632_y_15200.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_15584.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19328_y_16768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16448.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_16416.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16288_y_25088.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16480_y_24768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19264_y_16832.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_17408.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16704.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_16512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_17312.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16288_y_24512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_17408.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_17024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16640_y_24608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16128.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16672_y_24640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17152_y_25120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17024_y_24928.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18592_y_16672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17568_y_25120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16384.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_16512.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_17184.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16960.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24384.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19040_y_16352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16160.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17632_y_15488.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_15648.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_17056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_15744.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16320_y_25024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16736_y_24672.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17632_y_15424.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_16800.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_16064.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24992.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_17120.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19296_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_16896.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18304_y_16032.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17632_y_15136.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18976_y_17056.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19136_y_16128.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18656_y_16032.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_17408.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18528_y_16544.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19232_y_16608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18464_y_17024.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18752_y_16736.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_17568_y_15360.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_16640.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18912_y_16416.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_17280.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19104_y_17088.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19168_y_16544.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19008_y_15968.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16288.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18560_y_17440.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16672_y_24608.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16448_y_24352.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18336_y_16032.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18880_y_16384.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19360_y_16768.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18624_y_17184.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18848_y_16832.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19200_y_16992.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18944_y_17152.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_16160.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_19072_y_17504.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_15936.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_16608_y_24960.png',\n", - " 'patches/patient_010_node_4/patch_patient_010_node_4_x_18496_y_16672.png',\n", - " ...]" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" } ], "source": [ @@ -1779,56 +824,105 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ - "import copy\n", - "full_dataset_camelyon17 = copy.deepcopy(full_dataset)" + "full_dataset.metadata_fields\n", + "config = config_encode\n", + "#config_encode.groupby_fields\n", + "\n", + "train_grouper = CombinatorialGrouper(\n", + " dataset=full_dataset,\n", + " groupby_fields=config.groupby_fields)" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 20, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "image_base None\n" - ] + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "supported.datasets[config_encode.dataset]\n", - "print(config_camelyon.train_transform, config_encode.train_transform)" + "config_encode.eval_splits" ] }, { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], - "source": [] + "source": [ + "# Train/eval" + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train data...\n", + " celltype = H1-hESC: n = 507309\n", + " celltype = HCT116: n = 506458\n", + " celltype = HeLa-S3: n = 509974\n", + " celltype = HepG2: n = 503007\n", + " celltype = K562: n = 506847\n", + " celltype = A549: n = 0\n", + " celltype = GM12878: n = 0\n", + "Validation (ID) data...\n", + " celltype = H1-hESC: n = 433473\n", + " celltype = HCT116: n = 431398\n", + " celltype = HeLa-S3: n = 435455\n", + " celltype = HepG2: n = 433039\n", + " celltype = K562: n = 430163\n", + " celltype = A549: n = 0\n", + " celltype = GM12878: n = 0\n", + "Test data...\n", + " celltype = H1-hESC: n = 0\n", + " celltype = HCT116: n = 0\n", + " celltype = HeLa-S3: n = 0\n", + " celltype = HepG2: n = 0\n", + " celltype = K562: n = 0\n", + " celltype = A549: n = 0\n", + " celltype = GM12878: n = 437124\n", + "Validation (OOD) data...\n", + " celltype = H1-hESC: n = 0\n", + " celltype = HCT116: n = 0\n", + " celltype = HeLa-S3: n = 0\n", + " celltype = HepG2: n = 0\n", + " celltype = K562: n = 0\n", + " celltype = A549: n = 433986\n", + " celltype = GM12878: n = 0\n" + ] + }, + { + "ename": "ValueError", + "evalue": "Model not recognized.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;31m## Initialize algorithm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 62\u001b[0;31m algorithm = initialize_algorithm(\n\u001b[0m\u001b[1;32m 63\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdatasets\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/examples/algorithms/initializer.py\u001b[0m in \u001b[0;36minitialize_algorithm\u001b[0;34m(config, datasets, train_grouper)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m'ERM'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m algorithm = ERM(\n\u001b[0m\u001b[1;32m 33\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0md_out\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0md_out\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/examples/algorithms/ERM.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, config, d_out, grouper, loss, metric, n_train_steps)\u001b[0m\n\u001b[1;32m 6\u001b[0m def __init__(self, config, d_out, grouper, loss,\n\u001b[1;32m 7\u001b[0m metric, n_train_steps):\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minitialize_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0md_out\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0;31m# initialize module\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m super().__init__(\n", + "\u001b[0;32m~/wilds/examples/models/initializer.py\u001b[0m in \u001b[0;36minitialize_model\u001b[0;34m(config, d_out)\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mGINVirtual\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnum_tasks\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0md_out\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 30\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Model not recognized.'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 31\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mValueError\u001b[0m: Model not recognized." + ] + } + ], "source": [ - "train_grouper = CombinatorialGrouper(\n", - " dataset=full_dataset,\n", - " groupby_fields=config.groupby_fields)\n", - "\n", "datasets = defaultdict(dict)\n", "for split in full_dataset.split_dict.keys():\n", " if split=='train':\n", @@ -1898,81 +992,26 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", - "val_celltype = ['A549']\n", - "test_celltype = ['GM12878']\n", - "all_celltypes = train_celltypes + val_celltype + test_celltype\n", - "\n", - "metadata_map = {}\n", - "metadata_map['chr'] = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']\n", - "metadata_map['celltype'] = all_celltypes\n", - "\n", - "_split_dict = {\n", - " 'train': 0,\n", - " 'val-id': 1,\n", - " 'test': 2,\n", - " 'val-ood': 3\n", - "}\n", - "_split_names = {\n", - " 'train': 'Train',\n", - " 'val-id': 'Validation (ID)',\n", - " 'test': 'Test',\n", - " 'val-ood': 'Validation (OOD)'\n", - "}\n", - "_split_scheme = 'standard'" + "for " ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "('H1-hESC', 25.299736976623535)\n", - "('HCT116', 49.68733310699463)\n", - "('HeLa-S3', 74.65905213356018)\n", - "('HepG2', 99.33112812042236)\n", - "('K562', 124.1327919960022)\n", - "('A549', 149.19999814033508)\n", - "('GM12878', 174.0277030467987)\n" - ] - } - ], - "source": [ - "itime = time.time()\n", - "sequence_filename = os.path.join(data_dir, 'sequence.npz')\n", - "seq_arr = np.load(sequence_filename)\n", - "print(time.time() - itime)\n", - "\n", - "itime = time.time()\n", - "_seq_bp = {}\n", - "for chrom in seq_arr:\n", - " _seq_bp[chrom] = seq_arr[chrom]\n", - " print(chrom, time.time() - itime)\n", - "itime = time.time()\n", - "_dnase_allcelltypes = {}\n", - "for ct in all_celltypes:\n", - " dnase_filename = os.path.join(data_dir, '{}_dnase.npz'.format(ct))\n", - " dnase_npz_file = np.load(dnase_filename)\n", - " _dnase_allcelltypes[ct] = {}\n", - " for chrom in _seq_bp:\n", - " _dnase_allcelltypes[ct][chrom] = dnase_npz_file[chrom]\n", - " print(ct, time.time() - itime)" - ] + "outputs": [], + "source": [] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, - "source": [ - "# Train/eval" - ] + "outputs": [], + "source": [] }, { "cell_type": "code", diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index f8b66f25..184da7cd 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -108,6 +108,20 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): metadata_df['y'] = y_array self._metadata_df = metadata_df[non_ambig_mask] + samp_ndces = [] + itime = time.time() + for ct in self._all_celltypes: + neg_msk = np.logical_and((self._metadata_df['celltype'] == ct), (self._metadata_df['y'] == 0)) + pos_msk = np.logical_and((self._metadata_df['celltype'] == ct), (self._metadata_df['y'] == 1)) + neg_ndces = np.where(neg_msk)[0] + pos_ndces = np.where(pos_msk)[0] + np.random.seed(42) + samp_neg_ndces = np.random.choice(neg_ndces, size=len(pos_ndces), replace=False) + samp_ndces.extend(samp_neg_ndces) + samp_ndces.extend(pos_ndces) + print(ct, time.time() - itime) + self._metadata_df = self._metadata_df.iloc[samp_ndces, :] + train_regions_mask = np.isin(self._metadata_df['chr'], self._train_chroms) val_regions_mask = np.isin(self._metadata_df['chr'], self._test_chroms) train_celltype_mask = np.isin(self._metadata_df['celltype'], self._train_celltypes) From d67ca4aa72a70d72c0cc865a5e0f2978acd2f10d Mon Sep 17 00:00:00 2001 From: aikanor Date: Fri, 26 Feb 2021 07:39:22 -0800 Subject: [PATCH 070/244] adding new architecture --- dataset_preprocessing/encode-tfbs/README.md | 1 + examples/configs/datasets.py | 4 +- examples/configs/model.py | 2 +- examples/models/CNN_genome.py | 127 ++- examples/sbox_run_expt.ipynb | 1026 +++++++++++++++---- wilds/datasets/encodetfbs_dataset.py | 13 +- 6 files changed, 926 insertions(+), 247 deletions(-) diff --git a/dataset_preprocessing/encode-tfbs/README.md b/dataset_preprocessing/encode-tfbs/README.md index 616d4cb5..bf3f92c6 100644 --- a/dataset_preprocessing/encode-tfbs/README.md +++ b/dataset_preprocessing/encode-tfbs/README.md @@ -16,3 +16,4 @@ 5. Download the labels from the challenge into a label directory created for this purpose: - The training labels from https://www.synapse.org/#!Synapse:syn7413983 for the relevant transcription factor (e.g. https://www.synapse.org/#!Synapse:syn7415202 for the TF MAX). - The validation labels from https://www.synapse.org/#!Synapse:syn8441154 for the relevant transcription factor (e.g. https://www.synapse.org/#!Synapse:syn8442103 for the TF MAX). + - (Optional) The validation labels for the challenge's evaluation cell type from https://www.synapse.org/#!Synapse:syn8442975 for the relevant transcription factor (generally primary liver cells, e.g. https://www.synapse.org/#!Synapse:syn8443021 for the TF MAX). diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 4d941653..144954b6 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -107,7 +107,7 @@ }, 'encode-tfbs': { 'split_scheme': 'official', - 'model': 'beagle', + 'model': 'leopard', 'model_kwargs': {'pretrained': False}, 'train_transform': None, 'eval_transform': None, @@ -121,7 +121,7 @@ 'batch_size': 64, 'lr': 0.001, 'weight_decay': 0.01, - 'n_epochs': 1, + 'n_epochs': 5, 'n_groups_per_batch': 2, 'algo_log_metric': 'accuracy', # 'irm_lambda': 1.0, diff --git a/examples/configs/model.py b/examples/configs/model.py index af37ab8e..f4eb8779 100644 --- a/examples/configs/model.py +++ b/examples/configs/model.py @@ -37,5 +37,5 @@ 'target_resolution': (224, 224), }, 'logistic_regression': {}, - 'beagle': {}, + 'leopard': {}, } diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py index 1c65b567..3efdba20 100644 --- a/examples/models/CNN_genome.py +++ b/examples/models/CNN_genome.py @@ -4,51 +4,82 @@ import torch.nn as nn import torch.nn.functional as F -class Beagle(nn.Module): - """ - Neural net models over genomic sequence. Adapted from https://github.com/kundajelab/ChromDragoNN - - Input: - - s (Tensor): float torch tensor of shape (N, 5, 1000, 1) with batch size N. - - Output: - - prediction (Tensor): float torch tensor of shape (N, ) - """ - def __init__(self): - super(Beagle, self).__init__() - - self.dropout = 0.3 - self.num_cell_types = 1 - self.conv1 = nn.Conv2d(5, 300, (19, 1), stride = (1, 1), padding=(9,0)) - self.conv2 = nn.Conv2d(300, 200, (11, 1), stride = (1, 1), padding = (5,0)) - self.conv3 = nn.Conv2d(200, 200, (7, 1), stride = (1, 1), padding = (4,0)) - self.bn1 = nn.BatchNorm2d(300) - self.bn2 = nn.BatchNorm2d(200) - self.bn3 = nn.BatchNorm2d(200) - self.maxpool1 = nn.MaxPool2d((3, 1)) - self.maxpool2 = nn.MaxPool2d((4, 1)) - self.maxpool3 = nn.MaxPool2d((4, 1)) - - self.fc1 = nn.Linear(4200, 1000) - self.bn4 = nn.BatchNorm1d(1000) - - self.fc2 = nn.Linear(1000, 1000) - self.bn5 = nn.BatchNorm1d(1000) - - self.fc3 = nn.Linear(1000, self.num_cell_types) - - def forward(self, s): - s = s.permute(0, 2, 1).contiguous() # batch_size x 5 x 1000 - s = s.view(-1, 5, 1000, 1) # batch_size x 5 x 1000 x 1 [5 channels] - s = self.maxpool1(F.relu(self.bn1(self.conv1(s)))) # batch_size x 300 x 333 x 1 - s = self.maxpool2(F.relu(self.bn2(self.conv2(s)))) # batch_size x 200 x 83 x 1 - s = self.maxpool3(F.relu(self.bn3(self.conv3(s)))) # batch_size x 200 x 21 x 1 - s = s.view(-1, 4200) - conv_out = s - - s = F.dropout(F.relu(self.bn4(self.fc1(s))), p=self.dropout, training=self.training) # batch_size x 1000 - s = F.dropout(F.relu(self.bn5(self.fc2(s))), p=self.dropout, training=self.training) # batch_size x 1000 - - prediction = self.fc3(s) - - return s #, conv_out + + +def double_conv(in_channels, out_channels): + return nn.Sequential( + nn.Conv1d(in_channels, out_channels, 7, padding=3), + nn.BatchNorm1d(out_channels), + nn.ReLU(inplace=True), + nn.Conv1d(out_channels, out_channels, 7, padding=3), + nn.BatchNorm1d(out_channels), + nn.ReLU(inplace=True) + ) + + +class UNet(nn.Module): + + def __init__(self, n_class): + super().__init__() + + self.dconv_down1 = double_conv(6, 15) + self.dconv_down2 = double_conv(15, 22) + self.dconv_down3 = double_conv(22, 33) + self.dconv_down4 = double_conv(33, 49) + self.dconv_down5 = double_conv(49, 73) + self.dconv_down6 = double_conv(73, 109) + + self.maxpool = nn.MaxPool1d(2) + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + + self.dconv_up5 = double_conv(73 + 109, 73) + self.dconv_up4 = double_conv(49 + 73, 49) + self.dconv_up3 = double_conv(33 + 49, 33) + self.dconv_up2 = double_conv(22 + 33, 22) + self.dconv_up1 = double_conv(15 + 22, 15) + + self.conv_last = nn.Conv2d(15, n_class, 1) + + + def forward(self, x): + conv1 = self.dconv_down1(x) + x = self.maxpool(conv1) + + conv2 = self.dconv_down2(x) + x = self.maxpool(conv2) + + conv3 = self.dconv_down3(x) + x = self.maxpool(conv3) + + conv4 = self.dconv_down4(x) + x = self.maxpool(conv4) + + conv5 = self.dconv_down5(x) + x = self.maxpool(conv5) + + x = self.dconv_down6(x) + + x = self.upsample(x) + x = torch.cat([x, conv5], dim=1) + + x = self.dconv_up5(x) + x = self.upsample(x) + x = torch.cat([x, conv4], dim=1) + + x = self.dconv_up4(x) + x = self.upsample(x) + x = torch.cat([x, conv3], dim=1) + + x = self.dconv_up3(x) + x = self.upsample(x) + x = torch.cat([x, conv2], dim=1) + + x = self.dconv_up2(x) + x = self.upsample(x) + x = torch.cat([x, conv1], dim=1) + + x = self.dconv_up1(x) + + out = self.conv_last(x) + + return out diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb index e50f790b..2c56cdd6 100644 --- a/examples/sbox_run_expt.ipynb +++ b/examples/sbox_run_expt.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 123, "metadata": {}, "outputs": [ { @@ -21,7 +21,7 @@ "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpsutil\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpsutil\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mProcess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgetpid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmemory_info\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrss\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0;36m1024\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpsutil\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpsutil\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mProcess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgetpid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmemory_info\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrss\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0;36m1024\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'psutil'" ] } @@ -32,9 +32,17 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:The OGB package is out of date. Your version is 1.2.4, while the latest version is 1.2.5.\n" + ] + } + ], "source": [ "import os, csv\n", "import time\n", @@ -59,16 +67,16 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" + "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" ] }, - "execution_count": 3, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -163,7 +171,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -175,12 +183,13 @@ "config_encode = parser.parse_args(argstr_encode.split())\n", "config_encode = populate_defaults(config_encode)\n", "\n", - "config = config_camelyon" + "config = config_camelyon\n", + "# config = config_encode" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -292,7 +301,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -312,30 +321,26 @@ }, { "cell_type": "code", - "execution_count": 30, - "metadata": { - "jupyter": { - "source_hidden": true - } - }, + "execution_count": 7, + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "chr2 3.657395362854004\n", - "chr9 5.770605564117432\n", - "chr11 7.801896095275879\n", - "chr1 11.56663990020752\n", - "chr8 13.764073133468628\n", - "chr21 14.483267068862915\n", - "H1-hESC 20.850953817367554\n", - "HCT116 27.05355429649353\n", - "HeLa-S3 33.51919412612915\n", - "HepG2 39.89570116996765\n", - "K562 46.36982774734497\n", - "A549 52.82617139816284\n", - "GM12878 59.167165994644165\n" + "chr2 3.7666022777557373\n", + "chr9 5.9439966678619385\n", + "chr11 8.030796766281128\n", + "chr1 11.851332426071167\n", + "chr8 14.106642007827759\n", + "chr21 14.852506160736084\n", + "H1-hESC 14.853845119476318\n", + "HCT116 14.853914022445679\n", + "HeLa-S3 14.853951930999756\n", + "HepG2 14.853987216949463\n", + "K562 14.854026317596436\n", + "A549 14.8540620803833\n", + "GM12878 14.854098796844482\n" ] } ], @@ -371,7 +376,7 @@ "_all_celltypes = _train_celltypes + _val_celltype + _test_celltype\n", "\n", "_metadata_map = {}\n", - "_metadata_map['chr'] = _all_chroms #['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']\n", + "_metadata_map['chr'] = _all_chroms\n", "_metadata_map['celltype'] = _all_celltypes\n", "\n", "# Get the splits\n", @@ -402,12 +407,35 @@ "\n", "_dnase_allcelltypes = {}\n", "for ct in _all_celltypes:\n", + " \"\"\"\n", " dnase_filename = os.path.join(_data_dir, '{}_dnase.npz'.format(ct))\n", " dnase_npz_contents = np.load(dnase_filename)\n", " _dnase_allcelltypes[ct] = {}\n", " for chrom in _all_chroms: #_seq_bp:\n", " _dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom]\n", - " print(ct, time.time() - itime)\n", + " \"\"\"\n", + " _dnase_allcelltypes[ct] = 'DNASE.{}.fc.signal.bigwig'\n", + " print(ct, time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\"\\nfilter_msk = all_df['start'] >= 0\\nfilter_msk = all_df['start']%1000 == 0\\nall_df = all_df[filter_msk]\\n\"" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "itime = time.time()\n", "\n", "# Read in metadata dataframe from training+validation data\n", "train_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.train.labels.tsv.gz'.format(_transcription_factor)), sep='\\t')\n", @@ -416,36 +444,384 @@ "val_df = val_regions_labeled[np.isin(val_regions_labeled['chr'], _test_chroms)]\n", "all_df = pd.concat([training_df, val_df])\n", "\n", - "# Filter by start/stop coordinate if needed (TODO: remove for final version)\n", - "# filter_msk = all_df['start'] >= 0\n", - "# filter_msk = all_df['start']%1000 == 0\n", - "# all_df = all_df[filter_msk]\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "# train_regions_labeled.replace({'U': 0, 'B': 1, 'A': -1})\n", + "# a = \n", + "# np.random.choice(train_regions_labeled.shape[0], size=100000)\n", + "\n", + "v = val_regions_labeled.replace({'U': 0, 'B': 1, 'A': -1})\n", + "# seta = [full_dataset_encode.get_input(x) for x in a]\n", + "# seta[0].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([-7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7]),\n", + " array([ 189, 854, 3579, 11535, 35901, 126629, 621676,\n", + " 7944663, 67689, 13516, 6766, 3332, 3179, 1076,\n", + " 2427]))" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.unique(v[['A549', 'GM12878', 'H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']].sum(axis=1), return_counts=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + ":12: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " tc_chr['y'] = y_array\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "11.363114833831787\n", + "21.872379302978516\n", + "32.51760506629944\n", + "42.88175559043884\n", + "53.35902285575867\n", + "63.94557332992554\n", + "74.44822382926941\n", + "92.237633228302\n" + ] + } + ], + "source": [ + "itime = time.time()\n", "\n", + "# Get the y values, and remove ambiguous labels by default.\n", "pd_list = []\n", "for ct in _all_celltypes:\n", " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", + " y_array = tc_chr['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", + " \n", + " # Now filter out ambiguous labels\n", + " non_ambig_mask = (y_array != -1)\n", + " tc_chr['y'] = y_array\n", + " tc_chr = tc_chr[non_ambig_mask]\n", + " \n", " tc_chr.insert(len(tc_chr.columns), 'celltype', ct)\n", " pd_list.append(tc_chr)\n", + " print(time.time() - itime)\n", "metadata_df = pd.concat(pd_list)\n", "\n", - "# Get the y values, and remove ambiguous labels by default.\n", - "y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", - "non_ambig_mask = (y_array != -1)\n", - "metadata_df['y'] = y_array\n", - "_metadata_df = metadata_df[non_ambig_mask]" + "print(time.time() - itime)\n", + "\n", + "# y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", + "# non_ambig_mask = (y_array != -1)\n", + "# metadata_df['y'] = y_array\n", + "# _metadata_df = metadata_df[non_ambig_mask]\n", + "\n", + "# print(time.time() - itime)" ] }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 75, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
chrstartstopycelltype
2702470chr116008000H1-hESC
2702471chr116508500H1-hESC
2702472chr117009000H1-hESC
2702473chr117509500H1-hESC
2702474chr1180010000H1-hESC
..................
8843006chr81463632001463634000GM12878
8843007chr81463632501463634500GM12878
8843008chr81463633001463635000GM12878
8843009chr81463633501463635500GM12878
8843010chr81463634001463636000GM12878
\n", + "

131721055 rows × 5 columns

\n", + "
" + ], + "text/plain": [ + " chr start stop y celltype\n", + "2702470 chr11 600 800 0 H1-hESC\n", + "2702471 chr11 650 850 0 H1-hESC\n", + "2702472 chr11 700 900 0 H1-hESC\n", + "2702473 chr11 750 950 0 H1-hESC\n", + "2702474 chr11 800 1000 0 H1-hESC\n", + "... ... ... ... .. ...\n", + "8843006 chr8 146363200 146363400 0 GM12878\n", + "8843007 chr8 146363250 146363450 0 GM12878\n", + "8843008 chr8 146363300 146363500 0 GM12878\n", + "8843009 chr8 146363350 146363550 0 GM12878\n", + "8843010 chr8 146363400 146363600 0 GM12878\n", + "\n", + "[131721055 rows x 5 columns]" + ] + }, + "execution_count": 75, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "metadata_df\n", + "# tc_chr" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [], + "source": [ + "# window_size = 12800\n", + "# window_interval = window_size/2\n", + "# trl_mask = (train_regions_labeled['start']%window_interval == 0)\n", + "# train_regions_labeled[trl_mask]" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "686900" + ] + }, + "execution_count": 68, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(metadata_df['y'] == 1).sum()\n", + "# pd_list[0][non_ambig_mask]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# celltype_msk.sum()\n", + "\n", + "np.unique(_metadata_df['chr'])\n", + "\n", + "# celltype_msk = (_metadata_df['celltype'] == ct)\n", + "# np.where(celltype_msk)[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "_metadata_df" + ] + }, + { + "cell_type": "code", + "execution_count": 24, "metadata": { + "collapsed": true, "jupyter": { - "source_hidden": true + "outputs_hidden": true } }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "H1-hESC 32.5968804359436\n", + "H1-hESC 33.237690687179565\n", + "H1-hESC 37.01208806037903\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mpos_msk\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0m_metadata_df\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'y'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mct\u001b[0m \u001b[0;32min\u001b[0m \u001b[0m_all_celltypes\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mcelltype_msk\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0m_metadata_df\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'celltype'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mct\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0mneg_ct_msk\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogical_and\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcelltype_msk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mneg_msk\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mpos_ct_msk\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogical_and\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcelltype_msk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpos_msk\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/pandas/core/ops/common.py\u001b[0m in \u001b[0;36mnew_method\u001b[0;34m(self, other)\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0mother\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mitem_from_zerodim\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mother\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 65\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mother\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 66\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mnew_method\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/pandas/core/arraylike.py\u001b[0m in \u001b[0;36m__eq__\u001b[0;34m(self, other)\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0munpack_zerodim_and_defer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"__eq__\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__eq__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mother\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 29\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cmp_method\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mother\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moperator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meq\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0munpack_zerodim_and_defer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"__ne__\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/pandas/core/series.py\u001b[0m in \u001b[0;36m_cmp_method\u001b[0;34m(self, other, op)\u001b[0m\n\u001b[1;32m 4946\u001b[0m \u001b[0mrvalues\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mextract_array\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mother\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mextract_numpy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4947\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 4948\u001b[0;31m \u001b[0mres_values\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcomparison_op\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlvalues\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrvalues\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4949\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4950\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_construct_result\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mres_values\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mres_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/pandas/core/ops/array_ops.py\u001b[0m in \u001b[0;36mcomparison_op\u001b[0;34m(left, right, op)\u001b[0m\n\u001b[1;32m 241\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 242\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mis_object_dtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlvalues\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 243\u001b[0;31m \u001b[0mres_values\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcomp_method_OBJECT_ARRAY\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mop\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlvalues\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrvalues\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 244\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 245\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/pandas/core/ops/array_ops.py\u001b[0m in \u001b[0;36mcomp_method_OBJECT_ARRAY\u001b[0;34m(op, x, y)\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlibops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvec_compare\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mravel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mravel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 55\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlibops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscalar_compare\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mravel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 56\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "# Downsample negatives to balance each celltype\n", + "samp_ndces = []\n", + "itime = time.time()\n", + "neg_msk = (_metadata_df['y'] == 0)\n", + "pos_msk = (_metadata_df['y'] == 1)\n", + "for ct in _all_celltypes:\n", + " celltype_msk = (_metadata_df['celltype'] == ct)\n", + " neg_ct_msk = np.logical_and(celltype_msk, neg_msk)\n", + " pos_ct_msk = np.logical_and(celltype_msk, pos_msk)\n", + " print(ct, time.time() - itime)\n", + " neg_ndces = np.where(neg_ct_msk)[0]\n", + " pos_ndces = np.where(pos_ct_msk)[0]\n", + " print(ct, time.time() - itime)\n", + " np.random.seed(42)\n", + " samp_neg_ndces = np.random.choice(neg_ndces, size=len(pos_ndces), replace=False)\n", + " samp_ndces.extend(samp_neg_ndces)\n", + " samp_ndces.extend(pos_ndces)\n", + " print(ct, time.time() - itime)\n", + "_metadata_df = _metadata_df.iloc[samp_ndces, :]" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, "outputs": [], "source": [ + "# Downsample negatives to balance each celltype\n", "samp_ndces = []\n", "itime = time.time()\n", "for ct in _all_celltypes:\n", @@ -458,7 +834,6 @@ " samp_ndces.extend(samp_neg_ndces)\n", " samp_ndces.extend(pos_ndces)\n", " print(ct, time.time() - itime)\n", - "\n", "_metadata_df = _metadata_df.iloc[samp_ndces, :]\n", "\n", "train_regions_mask = np.isin(_metadata_df['chr'], _train_chroms)\n", @@ -503,8 +878,12 @@ }, { "cell_type": "code", - "execution_count": 23, - "metadata": {}, + "execution_count": 19, + "metadata": { + "jupyter": { + "source_hidden": true + } + }, "outputs": [], "source": [ "import os, time\n", @@ -542,8 +921,8 @@ " self._y_size = 1\n", " self._n_classes = 2\n", " \n", - " self._train_chroms = ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX']\n", - " # self._train_chroms = ['chr2', 'chr9', 'chr11']\n", + " # self._train_chroms = ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX']\n", + " self._train_chroms = ['chr2', 'chr9', 'chr11']\n", " self._test_chroms = ['chr1', 'chr8', 'chr21']\n", " self._transcription_factor = 'MAX'\n", " self._train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", @@ -553,7 +932,7 @@ " self._all_celltypes = self._train_celltypes + self._val_celltype + self._test_celltype\n", " \n", " self._metadata_map = {}\n", - " self._metadata_map['chr'] = self._all_chroms #['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']\n", + " self._metadata_map['chr'] = self._all_chroms\n", " self._metadata_map['celltype'] = self._all_celltypes\n", " \n", " # Get the splits\n", @@ -696,43 +1075,38 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 20, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "chr2 3.718320846557617\n", - "chr3 6.73882269859314\n", - "chr4 9.651247501373291\n", - "chr5 12.439628839492798\n", - "chr6 15.05026388168335\n", - "chr7 17.475954055786133\n", - "chr9 19.6206693649292\n", - "chr10 21.68758535385132\n", - "chr11 23.74817419052124\n", - "chr12 25.81403160095215\n", - "chr13 27.559557676315308\n", - "chr14 29.18643832206726\n", - "chr15 30.739391565322876\n", - "chr16 32.11144256591797\n", - "chr17 33.348127126693726\n", - "chr18 34.53834342956543\n", - "chr19 35.434733629226685\n", - "chr20 36.399296283721924\n", - "chr22 37.1924102306366\n", - "chrX 39.56284308433533\n", - "chr1 43.3526566028595\n", - "chr8 45.583492040634155\n", - "chr21 46.311339378356934\n", - "H1-hESC 66.45292735099792\n", - "HCT116 86.06067085266113\n", - "HeLa-S3 106.47142815589905\n", - "HepG2 126.59437656402588\n", - "K562 146.93650436401367\n", - "A549 167.19306707382202\n", - "GM12878 187.4349775314331\n" + "chr2 3.7390823364257812\n", + "chr9 5.909312963485718\n", + "chr11 8.020122051239014\n", + "chr1 11.871179103851318\n", + "chr8 14.147786140441895\n", + "chr21 14.896430492401123\n", + "H1-hESC 21.391544818878174\n", + "HCT116 27.753155946731567\n", + "HeLa-S3 34.33590316772461\n", + "HepG2 40.81141257286072\n", + "K562 47.39495897293091\n", + "A549 54.245203495025635\n", + "GM12878 60.693068742752075\n", + "H1-hESC 16.79085922241211\n", + "HCT116 33.788668394088745\n", + "HeLa-S3 51.1968936920166\n", + "HepG2 68.32299137115479\n", + "K562 85.74746584892273\n", + "A549 103.05137896537781\n", + "GM12878 120.52022075653076\n" ] } ], @@ -746,24 +1120,33 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 39, + "metadata": {}, + "outputs": [], + "source": [ + "a = np.random.choice(1210796, size=128)\n", + "seta = [full_dataset_encode.get_input(x) for x in a]\n", + "seta[0].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "(array(['A549', 'GM12878', 'H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'],\n", - " dtype=object),\n", - " array([ 5118, 1702, 8460, 12806, 8348, 11774, 12518]))" + "(array([0, 1, 2, 3]), array([2804551, 498433, 34145, 100851]))" ] }, - "execution_count": 20, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "np.unique(full_dataset_encode._metadata_df['celltype'], return_counts=True)" + "np.unique(full_dataset_encode._metadata_df['split'], return_counts=True)" ] }, { @@ -810,26 +1193,20 @@ "print(np.unique(full_dataset.y_array.numpy(), return_counts=True))\n", "print(np.unique(full_dataset._metadata_df['split'], return_counts=True))\n", "\n", - "#full_dataset._input_array" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# pd.read_csv(os.path.join('data/camelyon17_v1.0/metadata.csv'), index_col=0, dtype={'patient': 'str'})" + "#full_dataset._input_array\n", + "\n", + "#full_dataset_encode._seq_bp['chr11'].shape\n", + "full_dataset_encode._dnase_allcelltypes['HCT116']['chr11'].shape" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 68, "metadata": {}, "outputs": [], "source": [ "full_dataset.metadata_fields\n", - "config = config_encode\n", + "config = config_camelyon\n", "#config_encode.groupby_fields\n", "\n", "train_grouper = CombinatorialGrouper(\n", @@ -839,86 +1216,51 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 118, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[]" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "config_encode.eval_splits" + "# full_dataset = copy.deepcopy(full_dataset_encode)\n", + "full_dataset = copy.deepcopy(full_dataset_camelyon17)\n", + "# full_dataset_camelyon17.split_dict" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "markdown", "metadata": {}, "source": [ - "# Train/eval" + "# Initialize algorithm" ] }, { "cell_type": "code", - "execution_count": 15, - "metadata": {}, + "execution_count": 120, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Train data...\n", - " celltype = H1-hESC: n = 507309\n", - " celltype = HCT116: n = 506458\n", - " celltype = HeLa-S3: n = 509974\n", - " celltype = HepG2: n = 503007\n", - " celltype = K562: n = 506847\n", - " celltype = A549: n = 0\n", - " celltype = GM12878: n = 0\n", - "Validation (ID) data...\n", - " celltype = H1-hESC: n = 433473\n", - " celltype = HCT116: n = 431398\n", - " celltype = HeLa-S3: n = 435455\n", - " celltype = HepG2: n = 433039\n", - " celltype = K562: n = 430163\n", - " celltype = A549: n = 0\n", - " celltype = GM12878: n = 0\n", - "Test data...\n", - " celltype = H1-hESC: n = 0\n", - " celltype = HCT116: n = 0\n", - " celltype = HeLa-S3: n = 0\n", - " celltype = HepG2: n = 0\n", - " celltype = K562: n = 0\n", - " celltype = A549: n = 0\n", - " celltype = GM12878: n = 437124\n", - "Validation (OOD) data...\n", - " celltype = H1-hESC: n = 0\n", - " celltype = HCT116: n = 0\n", - " celltype = HeLa-S3: n = 0\n", - " celltype = HepG2: n = 0\n", - " celltype = K562: n = 0\n", - " celltype = A549: n = 433986\n", - " celltype = GM12878: n = 0\n" - ] - }, { "ename": "ValueError", - "evalue": "Model not recognized.", + "evalue": "I/O operation on closed file", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;31m## Initialize algorithm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 62\u001b[0;31m algorithm = initialize_algorithm(\n\u001b[0m\u001b[1;32m 63\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdatasets\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/examples/algorithms/initializer.py\u001b[0m in \u001b[0;36minitialize_algorithm\u001b[0;34m(config, datasets, train_grouper)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m'ERM'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m algorithm = ERM(\n\u001b[0m\u001b[1;32m 33\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0md_out\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0md_out\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/examples/algorithms/ERM.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, config, d_out, grouper, loss, metric, n_train_steps)\u001b[0m\n\u001b[1;32m 6\u001b[0m def __init__(self, config, d_out, grouper, loss,\n\u001b[1;32m 7\u001b[0m metric, n_train_steps):\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minitialize_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0md_out\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0;31m# initialize module\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m super().__init__(\n", - "\u001b[0;32m~/wilds/examples/models/initializer.py\u001b[0m in \u001b[0;36minitialize_model\u001b[0;34m(config, d_out)\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mGINVirtual\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnum_tasks\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0md_out\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 30\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Model not recognized.'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 31\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mValueError\u001b[0m: Model not recognized." + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0mlog_grouper\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_grouper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 59\u001b[0;31m \u001b[0mlog_group_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdatasets\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_grouper\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 60\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;31m## Initialize algorithm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/examples/utils.py\u001b[0m in \u001b[0;36mlog_group_data\u001b[0;34m(datasets, grouper, logger)\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[0mname\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'name'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 69\u001b[0m \u001b[0mdataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'dataset'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 70\u001b[0;31m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'{name} data...\\n'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 71\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mgrouper\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf' n = {len(dataset)}\\n'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/examples/utils.py\u001b[0m in \u001b[0;36mwrite\u001b[0;34m(self, msg)\u001b[0m\n\u001b[1;32m 97\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 98\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 99\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconsole\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 100\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfile\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfile\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/ipykernel/iostream.py\u001b[0m in \u001b[0;36mwrite\u001b[0;34m(self, string)\u001b[0m\n\u001b[1;32m 392\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 393\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpub_thread\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 394\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'I/O operation on closed file'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 395\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 396\u001b[0m \u001b[0;31m# Make sure that we're handling unicode\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mValueError\u001b[0m: I/O operation on closed file" ] } ], @@ -992,26 +1334,76 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 135, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "device(type='cuda', index=0)" + ] + }, + "execution_count": 135, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "for " + "algorithm.device\n", + "# datasets['train']['loader']" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 134, "metadata": {}, - "outputs": [], - "source": [] + "outputs": [ + { + "ename": "RuntimeError", + "evalue": "CUDA out of memory. Tried to allocate 14.00 MiB (GPU 0; 11.93 GiB total capacity; 10.94 GiB already allocated; 5.06 MiB free; 11.32 GiB reserved in total by PyTorch)", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# datasets['train']['dataset'].size()\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torchvision/models/densenet.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 190\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 191\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 192\u001b[0;31m \u001b[0mfeatures\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 193\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minplace\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 194\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madaptive_avg_pool2d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 118\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torchvision/models/densenet.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, init_features)\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[0mfeatures\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0minit_features\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 110\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlayer\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 111\u001b[0;31m \u001b[0mnew_features\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlayer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 112\u001b[0m \u001b[0mfeatures\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnew_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torchvision/models/densenet.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 82\u001b[0m \u001b[0mbottleneck_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcall_checkpoint_bottleneck\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprev_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 84\u001b[0;31m \u001b[0mbottleneck_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbn_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprev_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 85\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0mnew_features\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbottleneck_output\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torchvision/models/densenet.py\u001b[0m in \u001b[0;36mbn_function\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0;31m# type: (List[Tensor]) -> Tensor\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0mconcated_features\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 41\u001b[0;31m \u001b[0mbottleneck_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconcated_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# noqa: T484\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 42\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mbottleneck_output\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 129\u001b[0m \u001b[0mused\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mnormalization\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0;32min\u001b[0m \u001b[0meval\u001b[0m \u001b[0mmode\u001b[0m \u001b[0mwhen\u001b[0m \u001b[0mbuffers\u001b[0m \u001b[0mare\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 130\u001b[0m \"\"\"\n\u001b[0;32m--> 131\u001b[0;31m return F.batch_norm(\n\u001b[0m\u001b[1;32m 132\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 133\u001b[0m \u001b[0;31m# If buffers are not to be tracked, ensure that they won't be updated\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mbatch_norm\u001b[0;34m(input, running_mean, running_var, weight, bias, training, momentum, eps)\u001b[0m\n\u001b[1;32m 2054\u001b[0m \u001b[0m_verify_batch_size\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2055\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2056\u001b[0;31m return torch.batch_norm(\n\u001b[0m\u001b[1;32m 2057\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrunning_mean\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrunning_var\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2058\u001b[0m \u001b[0mtraining\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmomentum\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meps\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackends\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcudnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menabled\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 14.00 MiB (GPU 0; 11.93 GiB total capacity; 10.94 GiB already allocated; 5.06 MiB free; 11.32 GiB reserved in total by PyTorch)" + ] + } + ], + "source": [ + "# datasets['train']['dataset'].size()\n", + "algorithm.model(x.to(algorithm.device))" + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 131, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "for batch in datasets['train']['loader']:\n", + " x, y_true, metadata = batch\n", + " break" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Train" + ] }, { "cell_type": "code", @@ -1097,7 +1489,28 @@ }, { "cell_type": "code", - "execution_count": 78, + "execution_count": 126, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 126, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "for b in full_dataset:\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": 33, "metadata": {}, "outputs": [], "source": [ @@ -1130,12 +1543,12 @@ "\n", " self.dropout = 0.3\n", " self.num_cell_types = 1\n", - " self.conv1 = nn.Conv2d(5, 300, (19, 1), stride = (1, 1), padding=(9,0))\n", - " self.conv2 = nn.Conv2d(300, 200, (11, 1), stride = (1, 1), padding = (5,0))\n", - " self.conv3 = nn.Conv2d(200, 200, (7, 1), stride = (1, 1), padding = (4,0))\n", - " self.bn1 = nn.BatchNorm2d(300)\n", - " self.bn2 = nn.BatchNorm2d(200)\n", - " self.bn3 = nn.BatchNorm2d(200)\n", + " self.conv1 = nn.Conv2d(5, 50, (19, 1), stride = (1, 1), padding=(9,0))\n", + " self.conv2 = nn.Conv2d(50, 50, (11, 1), stride = (1, 1), padding = (5,0))\n", + " self.conv3 = nn.Conv2d(50, 50, (7, 1), stride = (1, 1), padding = (4,0))\n", + " self.bn1 = nn.BatchNorm2d(50)\n", + " self.bn2 = nn.BatchNorm2d(50)\n", + " self.bn3 = nn.BatchNorm2d(50)\n", " self.maxpool1 = nn.MaxPool2d((3, 1))\n", " self.maxpool2 = nn.MaxPool2d((4, 1))\n", " self.maxpool3 = nn.MaxPool2d((4, 1))\n", @@ -1167,29 +1580,242 @@ }, { "cell_type": "code", - "execution_count": 86, + "execution_count": 124, "metadata": {}, + "outputs": [], + "source": [ + "def double_conv(in_channels, out_channels): \n", + " return nn.Sequential(\n", + " nn.Conv1d(in_channels, out_channels, 7, padding=3), \n", + " nn.BatchNorm1d(out_channels), \n", + " nn.ReLU(inplace=True),\n", + " nn.Conv1d(out_channels, out_channels, 7, padding=3), \n", + " nn.BatchNorm1d(out_channels), \n", + " nn.ReLU(inplace=True)\n", + " )\n", + "\n", + "\n", + "class UNet(nn.Module):\n", + "\n", + " def __init__(self, n_class):\n", + " super().__init__()\n", + " \n", + " self.dconv_down1 = double_conv(6, 15)\n", + " self.dconv_down2 = double_conv(15, 22)\n", + " self.dconv_down3 = double_conv(22, 33)\n", + " self.dconv_down4 = double_conv(33, 49)\n", + " self.dconv_down5 = double_conv(49, 73)\n", + " self.dconv_down6 = double_conv(73, 109)\n", + "\n", + " self.maxpool = nn.MaxPool1d(2)\n", + " self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) \n", + " \n", + " self.dconv_up5 = double_conv(73 + 109, 73)\n", + " self.dconv_up4 = double_conv(49 + 73, 49)\n", + " self.dconv_up3 = double_conv(33 + 49, 33)\n", + " self.dconv_up2 = double_conv(22 + 33, 22)\n", + " self.dconv_up1 = double_conv(15 + 22, 15)\n", + " \n", + " self.conv_last = nn.Conv2d(15, n_class, 1)\n", + " \n", + " \n", + " def forward(self, x):\n", + " conv1 = self.dconv_down1(x)\n", + " x = self.maxpool(conv1)\n", + "\n", + " conv2 = self.dconv_down2(x)\n", + " x = self.maxpool(conv2)\n", + " \n", + " conv3 = self.dconv_down3(x)\n", + " x = self.maxpool(conv3)\n", + " \n", + " conv4 = self.dconv_down4(x)\n", + " x = self.maxpool(conv4)\n", + " \n", + " conv5 = self.dconv_down5(x)\n", + " x = self.maxpool(conv5)\n", + " \n", + " x = self.dconv_down6(x)\n", + " \n", + " x = self.upsample(x) \n", + " x = torch.cat([x, conv5], dim=1)\n", + " \n", + " x = self.dconv_up5(x)\n", + " x = self.upsample(x) \n", + " x = torch.cat([x, conv4], dim=1)\n", + " \n", + " x = self.dconv_up4(x)\n", + " x = self.upsample(x) \n", + " x = torch.cat([x, conv3], dim=1)\n", + " \n", + " x = self.dconv_up3(x)\n", + " x = self.upsample(x) \n", + " x = torch.cat([x, conv2], dim=1) \n", + "\n", + " x = self.dconv_up2(x)\n", + " x = self.upsample(x) \n", + " x = torch.cat([x, conv1], dim=1) \n", + " \n", + " x = self.dconv_up1(x)\n", + " \n", + " out = self.conv_last(x)\n", + " \n", + " return out" + ] + }, + { + "cell_type": "code", + "execution_count": 125, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, "outputs": [ { "data": { "text/plain": [ - "[('nnet.0.weight', 33280),\n", - " ('nnet.0.bias', 320),\n", - " ('bdlstm.0.weight_ih_l0', 409600),\n", - " ('bdlstm.0.weight_hh_l0', 409600),\n", - " ('bdlstm.0.bias_ih_l0', 1280),\n", - " ('bdlstm.0.bias_hh_l0', 1280),\n", - " ('bdlstm.0.weight_ih_l0_reverse', 409600),\n", - " ('bdlstm.0.weight_hh_l0_reverse', 409600),\n", - " ('bdlstm.0.bias_ih_l0_reverse', 1280),\n", - " ('bdlstm.0.bias_hh_l0_reverse', 1280),\n", - " ('classifier.1.weight', 592000),\n", - " ('classifier.1.bias', 925),\n", - " ('classifier.3.weight', 4625),\n", - " ('classifier.3.bias', 5)]" + "[('dconv_down1.0.weight', 630),\n", + " ('dconv_down1.0.bias', 15),\n", + " ('dconv_down1.1.weight', 15),\n", + " ('dconv_down1.1.bias', 15),\n", + " ('dconv_down1.3.weight', 1575),\n", + " ('dconv_down1.3.bias', 15),\n", + " ('dconv_down1.4.weight', 15),\n", + " ('dconv_down1.4.bias', 15),\n", + " ('dconv_down2.0.weight', 2310),\n", + " ('dconv_down2.0.bias', 22),\n", + " ('dconv_down2.1.weight', 22),\n", + " ('dconv_down2.1.bias', 22),\n", + " ('dconv_down2.3.weight', 3388),\n", + " ('dconv_down2.3.bias', 22),\n", + " ('dconv_down2.4.weight', 22),\n", + " ('dconv_down2.4.bias', 22),\n", + " ('dconv_down3.0.weight', 5082),\n", + " ('dconv_down3.0.bias', 33),\n", + " ('dconv_down3.1.weight', 33),\n", + " ('dconv_down3.1.bias', 33),\n", + " ('dconv_down3.3.weight', 7623),\n", + " ('dconv_down3.3.bias', 33),\n", + " ('dconv_down3.4.weight', 33),\n", + " ('dconv_down3.4.bias', 33),\n", + " ('dconv_down4.0.weight', 11319),\n", + " ('dconv_down4.0.bias', 49),\n", + " ('dconv_down4.1.weight', 49),\n", + " ('dconv_down4.1.bias', 49),\n", + " ('dconv_down4.3.weight', 16807),\n", + " ('dconv_down4.3.bias', 49),\n", + " ('dconv_down4.4.weight', 49),\n", + " ('dconv_down4.4.bias', 49),\n", + " ('dconv_down5.0.weight', 25039),\n", + " ('dconv_down5.0.bias', 73),\n", + " ('dconv_down5.1.weight', 73),\n", + " ('dconv_down5.1.bias', 73),\n", + " ('dconv_down5.3.weight', 37303),\n", + " ('dconv_down5.3.bias', 73),\n", + " ('dconv_down5.4.weight', 73),\n", + " ('dconv_down5.4.bias', 73),\n", + " ('dconv_down6.0.weight', 55699),\n", + " ('dconv_down6.0.bias', 109),\n", + " ('dconv_down6.1.weight', 109),\n", + " ('dconv_down6.1.bias', 109),\n", + " ('dconv_down6.3.weight', 83167),\n", + " ('dconv_down6.3.bias', 109),\n", + " ('dconv_down6.4.weight', 109),\n", + " ('dconv_down6.4.bias', 109),\n", + " ('dconv_up5.0.weight', 93002),\n", + " ('dconv_up5.0.bias', 73),\n", + " ('dconv_up5.1.weight', 73),\n", + " ('dconv_up5.1.bias', 73),\n", + " ('dconv_up5.3.weight', 37303),\n", + " ('dconv_up5.3.bias', 73),\n", + " ('dconv_up5.4.weight', 73),\n", + " ('dconv_up5.4.bias', 73),\n", + " ('dconv_up4.0.weight', 41846),\n", + " ('dconv_up4.0.bias', 49),\n", + " ('dconv_up4.1.weight', 49),\n", + " ('dconv_up4.1.bias', 49),\n", + " ('dconv_up4.3.weight', 16807),\n", + " ('dconv_up4.3.bias', 49),\n", + " ('dconv_up4.4.weight', 49),\n", + " ('dconv_up4.4.bias', 49),\n", + " ('dconv_up3.0.weight', 18942),\n", + " ('dconv_up3.0.bias', 33),\n", + " ('dconv_up3.1.weight', 33),\n", + " ('dconv_up3.1.bias', 33),\n", + " ('dconv_up3.3.weight', 7623),\n", + " ('dconv_up3.3.bias', 33),\n", + " ('dconv_up3.4.weight', 33),\n", + " ('dconv_up3.4.bias', 33),\n", + " ('dconv_up2.0.weight', 8470),\n", + " ('dconv_up2.0.bias', 22),\n", + " ('dconv_up2.1.weight', 22),\n", + " ('dconv_up2.1.bias', 22),\n", + " ('dconv_up2.3.weight', 3388),\n", + " ('dconv_up2.3.bias', 22),\n", + " ('dconv_up2.4.weight', 22),\n", + " ('dconv_up2.4.bias', 22),\n", + " ('dconv_up1.0.weight', 3885),\n", + " ('dconv_up1.0.bias', 15),\n", + " ('dconv_up1.1.weight', 15),\n", + " ('dconv_up1.1.bias', 15),\n", + " ('dconv_up1.3.weight', 1575),\n", + " ('dconv_up1.3.bias', 15),\n", + " ('dconv_up1.4.weight', 15),\n", + " ('dconv_up1.4.bias', 15),\n", + " ('conv_last.weight', 30),\n", + " ('conv_last.bias', 2)]" ] }, - "execution_count": 86, + "execution_count": 125, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model = UNet(2)\n", + "#model = DanQ(50, 5)\n", + "\n", + "lst = [(x[0], x[1].numel()) for x in model.named_parameters()]\n", + "#np.sum([x[1] for x in lst])\n", + "count_parameters(model)\n", + "lst" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[('conv1.weight', 4750),\n", + " ('conv1.bias', 50),\n", + " ('conv2.weight', 27500),\n", + " ('conv2.bias', 50),\n", + " ('conv3.weight', 17500),\n", + " ('conv3.bias', 50),\n", + " ('bn1.weight', 50),\n", + " ('bn1.bias', 50),\n", + " ('bn2.weight', 50),\n", + " ('bn2.bias', 50),\n", + " ('bn3.weight', 50),\n", + " ('bn3.bias', 50),\n", + " ('fc1.weight', 4200000),\n", + " ('fc1.bias', 1000),\n", + " ('bn4.weight', 1000),\n", + " ('bn4.bias', 1000),\n", + " ('fc2.weight', 1000000),\n", + " ('fc2.bias', 1000),\n", + " ('bn5.weight', 1000),\n", + " ('bn5.bias', 1000),\n", + " ('fc3.weight', 1000),\n", + " ('fc3.bias', 1)]" + ] + }, + "execution_count": 34, "metadata": {}, "output_type": "execute_result" } @@ -1198,14 +1824,28 @@ "def count_parameters(model):\n", " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", "\n", - "model = Beagle2()\n", - "model = DanQ(50, 5)\n", + "model = Beagle()\n", + "#model = DanQ(50, 5)\n", "\n", "lst = [(x[0], x[1].numel()) for x in model.named_parameters()]\n", "#np.sum([x[1] for x in lst])\n", "count_parameters(model)\n", "lst" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 184da7cd..08cba281 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -6,6 +6,8 @@ from wilds.common.grouper import CombinatorialGrouper from wilds.common.metrics.all_metrics import Accuracy +all_chrom_names = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] + class EncodeTFBSDataset(WILDSDataset): """ ENCODE-DREAM-wilds dataset of transcription factor binding sites. @@ -33,8 +35,8 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): self._y_size = 1 self._n_classes = 2 - # self._train_chroms = ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] - self._train_chroms = ['chr2', 'chr9', 'chr11'] + self._train_chroms = ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] + # self._train_chroms = ['chr2', 'chr9', 'chr11'] self._test_chroms = ['chr1', 'chr8', 'chr21'] self._transcription_factor = 'MAX' self._train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] @@ -44,7 +46,7 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): self._all_celltypes = self._train_celltypes + self._val_celltype + self._test_celltype self._metadata_map = {} - self._metadata_map['chr'] = self._all_chroms #['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] + self._metadata_map['chr'] = self._all_chroms self._metadata_map['celltype'] = self._all_celltypes # Get the splits @@ -75,11 +77,14 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): self._dnase_allcelltypes = {} for ct in self._all_celltypes: + """ dnase_filename = os.path.join(self._data_dir, '{}_dnase.npz'.format(ct)) dnase_npz_contents = np.load(dnase_filename) self._dnase_allcelltypes[ct] = {} for chrom in self._all_chroms: #self._seq_bp: self._dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom] + """ + self._dnase_allcelltypes[ct] = 'DNASE.{}.fc.signal.bigwig' print(ct, time.time() - itime) # Read in metadata dataframe from training+validation data @@ -90,9 +95,11 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): all_df = pd.concat([training_df, val_df]) # Filter by start/stop coordinate if needed (TODO: remove for final version) + """ filter_msk = all_df['start'] >= 0 filter_msk = all_df['start']%1000 == 0 all_df = all_df[filter_msk] + """ pd_list = [] for ct in self._all_celltypes: From 28f2039dafd4ad86444e98b5d566295493277342 Mon Sep 17 00:00:00 2001 From: aikanor Date: Sun, 28 Feb 2021 07:24:05 -0800 Subject: [PATCH 071/244] model changes (TODO: eval check-in) --- examples/models/CNN_genome.py | 55 +- examples/sbox_run_expt.ipynb | 1052 +++++++++++++++------------------ 2 files changed, 503 insertions(+), 604 deletions(-) diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py index 3efdba20..f1b90d07 100644 --- a/examples/models/CNN_genome.py +++ b/examples/models/CNN_genome.py @@ -42,43 +42,46 @@ def __init__(self, n_class): def forward(self, x): - conv1 = self.dconv_down1(x) - x = self.maxpool(conv1) + # input_size = 12800 + # input_channels = 6 + conv1 = self.dconv_down1(x) # Out: (input_size) x 15 + x = self.maxpool(conv1) # (input_size / 2) x 15 - conv2 = self.dconv_down2(x) - x = self.maxpool(conv2) + conv2 = self.dconv_down2(x) # (input_size / 2) x 22 + x = self.maxpool(conv2) # (input_size / 4) x 22 - conv3 = self.dconv_down3(x) - x = self.maxpool(conv3) + conv3 = self.dconv_down3(x) # (input_size / 4) x 33 + x = self.maxpool(conv3) # (input_size / 8) x 33 - conv4 = self.dconv_down4(x) - x = self.maxpool(conv4) + conv4 = self.dconv_down4(x) # (input_size / 8) x 49 + x = self.maxpool(conv4) # (input_size / 16) x 49 - conv5 = self.dconv_down5(x) - x = self.maxpool(conv5) + conv5 = self.dconv_down5(x) # (input_size / 16) x 73 + x = self.maxpool(conv5) # (input_size / 32) x 73 - x = self.dconv_down6(x) + conv6 = self.dconv_down6(x) # (input_size / 32) x 109 + # Encoder finished. - x = self.upsample(x) - x = torch.cat([x, conv5], dim=1) + x = self.upsample(conv6) # (input_size / 16) x 109 + x = torch.cat([x, conv5], dim=1) # (input_size / 16) x (109 + 73) - x = self.dconv_up5(x) - x = self.upsample(x) - x = torch.cat([x, conv4], dim=1) + x = self.dconv_up5(x) # (input_size / 16) x 73 + x = self.upsample(x) # (input_size / 8) x 73 + x = torch.cat([x, conv4], dim=1) # (input_size / 8) x (73 + 49) - x = self.dconv_up4(x) - x = self.upsample(x) - x = torch.cat([x, conv3], dim=1) + x = self.dconv_up4(x) # (input_size / 8) x 49 + x = self.upsample(x) # (input_size / 4) x 49 + x = torch.cat([x, conv3], dim=1) # (input_size / 4) x (49 + 33) - x = self.dconv_up3(x) - x = self.upsample(x) - x = torch.cat([x, conv2], dim=1) + x = self.dconv_up3(x) # (input_size / 4) x 33 + x = self.upsample(x) # (input_size / 2) x 33 + x = torch.cat([x, conv2], dim=1) # (input_size / 2) x (33 + 22) - x = self.dconv_up2(x) - x = self.upsample(x) - x = torch.cat([x, conv1], dim=1) + x = self.dconv_up2(x) # (input_size / 2) x 22 + x = self.upsample(x) # (input_size) x 22 + x = torch.cat([x, conv1], dim=1) # (input_size) x (22 + 15) - x = self.dconv_up1(x) + x = self.dconv_up1(x) # (input_size) x 15 out = self.conv_last(x) diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb index 2c56cdd6..66712a29 100644 --- a/examples/sbox_run_expt.ipynb +++ b/examples/sbox_run_expt.ipynb @@ -34,15 +34,7 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:root:The OGB package is out of date. Your version is 1.2.4, while the latest version is 1.2.5.\n" - ] - } - ], + "outputs": [], "source": [ "import os, csv\n", "import time\n", @@ -73,7 +65,7 @@ { "data": { "text/plain": [ - "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" + "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" ] }, "execution_count": 2, @@ -192,6 +184,13 @@ "execution_count": 4, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:The OGB package is out of date. Your version is 1.2.4, while the latest version is 1.2.5.\n" + ] + }, { "name": "stdout", "output_type": "stream", @@ -321,26 +320,26 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "chr2 3.7666022777557373\n", - "chr9 5.9439966678619385\n", - "chr11 8.030796766281128\n", - "chr1 11.851332426071167\n", - "chr8 14.106642007827759\n", - "chr21 14.852506160736084\n", - "H1-hESC 14.853845119476318\n", - "HCT116 14.853914022445679\n", - "HeLa-S3 14.853951930999756\n", - "HepG2 14.853987216949463\n", - "K562 14.854026317596436\n", - "A549 14.8540620803833\n", - "GM12878 14.854098796844482\n" + "chr2 3.764267683029175\n", + "chr9 5.914910078048706\n", + "chr11 7.964999675750732\n", + "chr1 11.748822927474976\n", + "chr8 14.01279878616333\n", + "chr21 14.737261772155762\n", + "H1-hESC 14.73790693283081\n", + "HCT116 14.737961292266846\n", + "HeLa-S3 14.737993240356445\n", + "HepG2 14.738024950027466\n", + "K562 14.73805570602417\n", + "A549 14.738086223602295\n", + "GM12878 14.738116979598999\n" ] } ], @@ -420,18 +419,15 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "\"\\nfilter_msk = all_df['start'] >= 0\\nfilter_msk = all_df['start']%1000 == 0\\nall_df = all_df[filter_msk]\\n\"" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "66.32568740844727\n" + ] } ], "source": [ @@ -447,44 +443,6 @@ "print(time.time() - itime)" ] }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [], - "source": [ - "# train_regions_labeled.replace({'U': 0, 'B': 1, 'A': -1})\n", - "# a = \n", - "# np.random.choice(train_regions_labeled.shape[0], size=100000)\n", - "\n", - "v = val_regions_labeled.replace({'U': 0, 'B': 1, 'A': -1})\n", - "# seta = [full_dataset_encode.get_input(x) for x in a]\n", - "# seta[0].shape" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(array([-7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7]),\n", - " array([ 189, 854, 3579, 11535, 35901, 126629, 621676,\n", - " 7944663, 67689, 13516, 6766, 3332, 3179, 1076,\n", - " 2427]))" - ] - }, - "execution_count": 40, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.unique(v[['A549', 'GM12878', 'H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']].sum(axis=1), return_counts=True)" - ] - }, { "cell_type": "code", "execution_count": 59, @@ -547,160 +505,6 @@ "# print(time.time() - itime)" ] }, - { - "cell_type": "code", - "execution_count": 75, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
chrstartstopycelltype
2702470chr116008000H1-hESC
2702471chr116508500H1-hESC
2702472chr117009000H1-hESC
2702473chr117509500H1-hESC
2702474chr1180010000H1-hESC
..................
8843006chr81463632001463634000GM12878
8843007chr81463632501463634500GM12878
8843008chr81463633001463635000GM12878
8843009chr81463633501463635500GM12878
8843010chr81463634001463636000GM12878
\n", - "

131721055 rows × 5 columns

\n", - "
" - ], - "text/plain": [ - " chr start stop y celltype\n", - "2702470 chr11 600 800 0 H1-hESC\n", - "2702471 chr11 650 850 0 H1-hESC\n", - "2702472 chr11 700 900 0 H1-hESC\n", - "2702473 chr11 750 950 0 H1-hESC\n", - "2702474 chr11 800 1000 0 H1-hESC\n", - "... ... ... ... .. ...\n", - "8843006 chr8 146363200 146363400 0 GM12878\n", - "8843007 chr8 146363250 146363450 0 GM12878\n", - "8843008 chr8 146363300 146363500 0 GM12878\n", - "8843009 chr8 146363350 146363550 0 GM12878\n", - "8843010 chr8 146363400 146363600 0 GM12878\n", - "\n", - "[131721055 rows x 5 columns]" - ] - }, - "execution_count": 75, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "metadata_df\n", - "# tc_chr" - ] - }, { "cell_type": "code", "execution_count": 42, @@ -715,7 +519,7 @@ }, { "cell_type": "code", - "execution_count": 68, + "execution_count": 108, "metadata": {}, "outputs": [ { @@ -724,7 +528,7 @@ "686900" ] }, - "execution_count": 68, + "execution_count": 108, "metadata": {}, "output_type": "execute_result" } @@ -736,60 +540,34 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# celltype_msk.sum()\n", - "\n", - "np.unique(_metadata_df['chr'])\n", - "\n", - "# celltype_msk = (_metadata_df['celltype'] == ct)\n", - "# np.where(celltype_msk)[0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, + "execution_count": 88, "metadata": {}, - "outputs": [], - "source": [ - "_metadata_df" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true - } - }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "H1-hESC 32.5968804359436\n", - "H1-hESC 33.237690687179565\n", - "H1-hESC 37.01208806037903\n" - ] - }, - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mpos_msk\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0m_metadata_df\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'y'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mct\u001b[0m \u001b[0;32min\u001b[0m \u001b[0m_all_celltypes\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mcelltype_msk\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0m_metadata_df\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'celltype'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mct\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0mneg_ct_msk\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogical_and\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcelltype_msk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mneg_msk\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mpos_ct_msk\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogical_and\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcelltype_msk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpos_msk\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/pandas/core/ops/common.py\u001b[0m in \u001b[0;36mnew_method\u001b[0;34m(self, other)\u001b[0m\n\u001b[1;32m 63\u001b[0m \u001b[0mother\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mitem_from_zerodim\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mother\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 65\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mother\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 66\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 67\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mnew_method\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/pandas/core/arraylike.py\u001b[0m in \u001b[0;36m__eq__\u001b[0;34m(self, other)\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0munpack_zerodim_and_defer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"__eq__\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__eq__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mother\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 29\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cmp_method\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mother\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moperator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meq\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0munpack_zerodim_and_defer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"__ne__\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/pandas/core/series.py\u001b[0m in \u001b[0;36m_cmp_method\u001b[0;34m(self, other, op)\u001b[0m\n\u001b[1;32m 4946\u001b[0m \u001b[0mrvalues\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mextract_array\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mother\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mextract_numpy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4947\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 4948\u001b[0;31m \u001b[0mres_values\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcomparison_op\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlvalues\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrvalues\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4949\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4950\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_construct_result\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mres_values\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mres_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/pandas/core/ops/array_ops.py\u001b[0m in \u001b[0;36mcomparison_op\u001b[0;34m(left, right, op)\u001b[0m\n\u001b[1;32m 241\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 242\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mis_object_dtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlvalues\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 243\u001b[0;31m \u001b[0mres_values\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcomp_method_OBJECT_ARRAY\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mop\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlvalues\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrvalues\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 244\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 245\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/pandas/core/ops/array_ops.py\u001b[0m in \u001b[0;36mcomp_method_OBJECT_ARRAY\u001b[0;34m(op, x, y)\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlibops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvec_compare\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mravel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mravel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 55\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlibops\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscalar_compare\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mravel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 56\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + "H1-hESC 8.10781979560852\n", + "H1-hESC 8.47616195678711\n", + "H1-hESC 9.822284698486328\n", + "HCT116 17.048683881759644\n", + "HCT116 17.41142964363098\n", + "HCT116 18.752415657043457\n", + "HeLa-S3 26.464386463165283\n", + "HeLa-S3 26.860748291015625\n", + "HeLa-S3 28.151614665985107\n", + "HepG2 35.439460039138794\n", + "HepG2 35.83507966995239\n", + "HepG2 37.079824924468994\n", + "K562 44.71583318710327\n", + "K562 45.092923164367676\n", + "K562 46.389798402786255\n", + "A549 53.895429372787476\n", + "A549 54.27841639518738\n", + "A549 55.64506816864014\n", + "GM12878 63.17967939376831\n", + "GM12878 63.545384883880615\n", + "GM12878 64.84915113449097\n" ] } ], @@ -801,34 +579,12 @@ "pos_msk = (_metadata_df['y'] == 1)\n", "for ct in _all_celltypes:\n", " celltype_msk = (_metadata_df['celltype'] == ct)\n", + " print(ct, time.time() - itime)\n", " neg_ct_msk = np.logical_and(celltype_msk, neg_msk)\n", " pos_ct_msk = np.logical_and(celltype_msk, pos_msk)\n", " print(ct, time.time() - itime)\n", " neg_ndces = np.where(neg_ct_msk)[0]\n", " pos_ndces = np.where(pos_ct_msk)[0]\n", - " print(ct, time.time() - itime)\n", - " np.random.seed(42)\n", - " samp_neg_ndces = np.random.choice(neg_ndces, size=len(pos_ndces), replace=False)\n", - " samp_ndces.extend(samp_neg_ndces)\n", - " samp_ndces.extend(pos_ndces)\n", - " print(ct, time.time() - itime)\n", - "_metadata_df = _metadata_df.iloc[samp_ndces, :]" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "metadata": {}, - "outputs": [], - "source": [ - "# Downsample negatives to balance each celltype\n", - "samp_ndces = []\n", - "itime = time.time()\n", - "for ct in _all_celltypes:\n", - " neg_msk = np.logical_and((_metadata_df['celltype'] == ct), (_metadata_df['y'] == 0))\n", - " pos_msk = np.logical_and((_metadata_df['celltype'] == ct), (_metadata_df['y'] == 1))\n", - " neg_ndces = np.where(neg_msk)[0]\n", - " pos_ndces = np.where(pos_msk)[0]\n", " np.random.seed(42)\n", " samp_neg_ndces = np.random.choice(neg_ndces, size=len(pos_ndces), replace=False)\n", " samp_ndces.extend(samp_neg_ndces)\n", @@ -878,7 +634,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 106, "metadata": { "jupyter": { "source_hidden": true @@ -996,13 +752,19 @@ " metadata_df['y'] = y_array\n", " self._metadata_df = metadata_df[non_ambig_mask]\n", " \n", + " # Downsample negatives to balance each celltype\n", " samp_ndces = []\n", " itime = time.time()\n", - " for ct in self._all_celltypes:\n", - " neg_msk = np.logical_and((self._metadata_df['celltype'] == ct), (self._metadata_df['y'] == 0))\n", - " pos_msk = np.logical_and((self._metadata_df['celltype'] == ct), (self._metadata_df['y'] == 1))\n", - " neg_ndces = np.where(neg_msk)[0]\n", - " pos_ndces = np.where(pos_msk)[0]\n", + " neg_msk = (self._metadata_df['y'] == 0)\n", + " pos_msk = (self._metadata_df['y'] == 1)\n", + " for ct in _all_celltypes:\n", + " celltype_msk = (self._metadata_df['celltype'] == ct)\n", + " print(ct, time.time() - itime)\n", + " neg_ct_msk = np.logical_and(celltype_msk, neg_msk)\n", + " pos_ct_msk = np.logical_and(celltype_msk, pos_msk)\n", + " print(ct, time.time() - itime)\n", + " neg_ndces = np.where(neg_ct_msk)[0]\n", + " pos_ndces = np.where(pos_ct_msk)[0]\n", " np.random.seed(42)\n", " samp_neg_ndces = np.random.choice(neg_ndces, size=len(pos_ndces), replace=False)\n", " samp_ndces.extend(samp_neg_ndces)\n", @@ -1075,38 +837,47 @@ }, { "cell_type": "code", - "execution_count": 20, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true - } - }, + "execution_count": 107, + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "chr2 3.7390823364257812\n", - "chr9 5.909312963485718\n", - "chr11 8.020122051239014\n", - "chr1 11.871179103851318\n", - "chr8 14.147786140441895\n", - "chr21 14.896430492401123\n", - "H1-hESC 21.391544818878174\n", - "HCT116 27.753155946731567\n", - "HeLa-S3 34.33590316772461\n", - "HepG2 40.81141257286072\n", - "K562 47.39495897293091\n", - "A549 54.245203495025635\n", - "GM12878 60.693068742752075\n", - "H1-hESC 16.79085922241211\n", - "HCT116 33.788668394088745\n", - "HeLa-S3 51.1968936920166\n", - "HepG2 68.32299137115479\n", - "K562 85.74746584892273\n", - "A549 103.05137896537781\n", - "GM12878 120.52022075653076\n" + "chr2 3.962329387664795\n", + "chr9 6.259538888931274\n", + "chr11 8.446826934814453\n", + "chr1 12.49940538406372\n", + "chr8 14.91869592666626\n", + "chr21 15.700694799423218\n", + "H1-hESC 23.95099449157715\n", + "HCT116 31.26502823829651\n", + "HeLa-S3 39.382277488708496\n", + "HepG2 47.24500226974487\n", + "K562 55.079211711883545\n", + "A549 62.405343532562256\n", + "GM12878 70.00356984138489\n", + "H1-hESC 8.160386562347412\n", + "H1-hESC 8.546203374862671\n", + "H1-hESC 9.868412971496582\n", + "HCT116 17.121587991714478\n", + "HCT116 17.524660110473633\n", + "HCT116 18.90956425666809\n", + "HeLa-S3 26.98938488960266\n", + "HeLa-S3 27.376858234405518\n", + "HeLa-S3 28.7989022731781\n", + "HepG2 36.29348182678223\n", + "HepG2 36.668752908706665\n", + "HepG2 38.151512145996094\n", + "K562 45.96789216995239\n", + "K562 46.33995985984802\n", + "K562 47.87280249595642\n", + "A549 55.380892276763916\n", + "A549 55.75924301147461\n", + "A549 57.22686314582825\n", + "GM12878 65.09361720085144\n", + "GM12878 65.50619888305664\n", + "GM12878 66.9196424484253\n" ] } ], @@ -1120,88 +891,29 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 118, "metadata": {}, "outputs": [], "source": [ - "a = np.random.choice(1210796, size=128)\n", - "seta = [full_dataset_encode.get_input(x) for x in a]\n", - "seta[0].shape" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(array([0, 1, 2, 3]), array([2804551, 498433, 34145, 100851]))" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.unique(full_dataset_encode._metadata_df['split'], return_counts=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([0, 0, 0, ..., 1, 1, 1]) torch.Size([60726])\n", - "(array([0, 1]), array([30426, 30300]))\n", - "(array([0, 1, 2, 3]), array([28556, 25350, 1702, 5118]))\n" - ] - } - ], - "source": [ - "full_dataset = copy.deepcopy(full_dataset_encode)\n", - "print(full_dataset._y_array, full_dataset._y_array.shape)\n", - "print(np.unique(full_dataset.y_array.numpy(), return_counts=True))\n", - "print(np.unique(full_dataset._metadata_df['split'], return_counts=True))\n", - "\n", - "#full_dataset._input_array" + "# full_dataset = copy.deepcopy(full_dataset_encode)\n", + "full_dataset = copy.deepcopy(full_dataset_camelyon17)\n", + "# full_dataset_camelyon17.split_dict" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 39, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([0, 0, 0, ..., 0, 0, 0]) torch.Size([5568233])\n", - "(array([0, 1]), array([5537933, 30300]))\n", - "(array([0, 1, 2, 3]), array([2533595, 2163528, 437124, 433986]))\n" - ] - } - ], + "outputs": [], "source": [ - "print(full_dataset._y_array, full_dataset._y_array.shape)\n", - "print(np.unique(full_dataset.y_array.numpy(), return_counts=True))\n", - "print(np.unique(full_dataset._metadata_df['split'], return_counts=True))\n", - "\n", - "#full_dataset._input_array\n", - "\n", - "#full_dataset_encode._seq_bp['chr11'].shape\n", - "full_dataset_encode._dnase_allcelltypes['HCT116']['chr11'].shape" + "a = np.random.choice(1210796, size=128)\n", + "seta = [full_dataset_encode.get_input(x) for x in a]\n", + "seta[0].shape" ] }, { "cell_type": "code", - "execution_count": 68, + "execution_count": 111, "metadata": {}, "outputs": [], "source": [ @@ -1216,22 +928,24 @@ }, { "cell_type": "code", - "execution_count": 118, + "execution_count": 104, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 104, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# full_dataset = copy.deepcopy(full_dataset_encode)\n", - "full_dataset = copy.deepcopy(full_dataset_camelyon17)\n", - "# full_dataset_camelyon17.split_dict" + "full_dataset" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "markdown", "metadata": {}, @@ -1241,26 +955,56 @@ }, { "cell_type": "code", - "execution_count": 120, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true - } - }, + "execution_count": 113, + "metadata": {}, "outputs": [ { - "ename": "ValueError", - "evalue": "I/O operation on closed file", + "name": "stdout", + "output_type": "stream", + "text": [ + "Train data...\n", + " hospital = 0: n = 53425\n", + " hospital = 1: n = 0\n", + " hospital = 2: n = 0\n", + " hospital = 3: n = 116959\n", + " hospital = 4: n = 132052\n", + "Validation (ID) data...\n", + " hospital = 0: n = 6011\n", + " hospital = 1: n = 0\n", + " hospital = 2: n = 0\n", + " hospital = 3: n = 12879\n", + " hospital = 4: n = 14670\n", + "Test data...\n", + " hospital = 0: n = 0\n", + " hospital = 1: n = 0\n", + " hospital = 2: n = 85054\n", + " hospital = 3: n = 0\n", + " hospital = 4: n = 0\n", + "Validation (OOD) data...\n", + " hospital = 0: n = 0\n", + " hospital = 1: n = 34904\n", + " hospital = 2: n = 0\n", + " hospital = 3: n = 0\n", + " hospital = 4: n = 0\n", + "Dout: 2\n" + ] + }, + { + "ename": "RuntimeError", + "evalue": "CUDA error: out of memory", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0mlog_grouper\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_grouper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 59\u001b[0;31m \u001b[0mlog_group_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdatasets\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_grouper\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 60\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;31m## Initialize algorithm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/examples/utils.py\u001b[0m in \u001b[0;36mlog_group_data\u001b[0;34m(datasets, grouper, logger)\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[0mname\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'name'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 69\u001b[0m \u001b[0mdataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'dataset'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 70\u001b[0;31m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf'{name} data...\\n'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 71\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mgrouper\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0mlogger\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf' n = {len(dataset)}\\n'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/examples/utils.py\u001b[0m in \u001b[0;36mwrite\u001b[0;34m(self, msg)\u001b[0m\n\u001b[1;32m 97\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 98\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 99\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconsole\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 100\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfile\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfile\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwrite\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/ipykernel/iostream.py\u001b[0m in \u001b[0;36mwrite\u001b[0;34m(self, string)\u001b[0m\n\u001b[1;32m 392\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 393\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpub_thread\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 394\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'I/O operation on closed file'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 395\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 396\u001b[0m \u001b[0;31m# Make sure that we're handling unicode\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mValueError\u001b[0m: I/O operation on closed file" + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;31m## Initialize algorithm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 62\u001b[0;31m algorithm = initialize_algorithm(\n\u001b[0m\u001b[1;32m 63\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdatasets\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/examples/algorithms/initializer.py\u001b[0m in \u001b[0;36minitialize_algorithm\u001b[0;34m(config, datasets, train_grouper)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m'ERM'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m algorithm = ERM(\n\u001b[0m\u001b[1;32m 33\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0md_out\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0md_out\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/examples/algorithms/ERM.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, config, d_out, grouper, loss, metric, n_train_steps)\u001b[0m\n\u001b[1;32m 6\u001b[0m def __init__(self, config, d_out, grouper, loss,\n\u001b[1;32m 7\u001b[0m metric, n_train_steps):\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minitialize_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0md_out\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0;31m# initialize module\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m super().__init__(\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mto\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 610\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 611\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 612\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 613\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 614\u001b[0m def register_backward_hook(\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m 357\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 358\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 359\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 360\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m 357\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 358\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 359\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 360\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m 379\u001b[0m \u001b[0;31m# `with torch.no_grad():`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 380\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 381\u001b[0;31m \u001b[0mparam_applied\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 382\u001b[0m \u001b[0mshould_use_set_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparam_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 383\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mshould_use_set_data\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mconvert\u001b[0;34m(t)\u001b[0m\n\u001b[1;32m 608\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconvert_to_format\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 609\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmemory_format\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconvert_to_format\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 610\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 611\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 612\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: CUDA error: out of memory" ] } ], @@ -1334,29 +1078,203 @@ }, { "cell_type": "code", - "execution_count": 135, + "execution_count": 91, "metadata": {}, "outputs": [ { "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
chrstartstopycelltypesplit
3831225chr11917992501917994500H1-hESC1
4190052chr12097406002097408000H1-hESC1
7241915chr866306500663067000H1-hESC1
21449377chr238487450384876500H1-hESC0
45876013chr9569770056979000H1-hESC0
.....................
8841297chr81462777501462779501GM128782
8841298chr81462778001462780001GM128782
8841299chr81462778501462780501GM128782
8841300chr81462779001462781001GM128782
8841301chr81462779501462781501GM128782
\n", + "

1210796 rows × 6 columns

\n", + "
" + ], "text/plain": [ - "device(type='cuda', index=0)" + " chr start stop y celltype split\n", + "3831225 chr1 191799250 191799450 0 H1-hESC 1\n", + "4190052 chr1 209740600 209740800 0 H1-hESC 1\n", + "7241915 chr8 66306500 66306700 0 H1-hESC 1\n", + "21449377 chr2 38487450 38487650 0 H1-hESC 0\n", + "45876013 chr9 5697700 5697900 0 H1-hESC 0\n", + "... ... ... ... .. ... ...\n", + "8841297 chr8 146277750 146277950 1 GM12878 2\n", + "8841298 chr8 146277800 146278000 1 GM12878 2\n", + "8841299 chr8 146277850 146278050 1 GM12878 2\n", + "8841300 chr8 146277900 146278100 1 GM12878 2\n", + "8841301 chr8 146277950 146278150 1 GM12878 2\n", + "\n", + "[1210796 rows x 6 columns]" ] }, - "execution_count": 135, + "execution_count": 91, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "algorithm.device\n", + "# algorithm.device\n", + "_metadata_df\n", "# datasets['train']['loader']" ] }, { "cell_type": "code", - "execution_count": 134, + "execution_count": 90, "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'datasets' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'loader'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetadata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name 'datasets' is not defined" + ] + } + ], + "source": [ + "for batch in datasets['train']['loader']:\n", + " x, y_true, metadata = batch\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": 134, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, "outputs": [ { "ename": "RuntimeError", @@ -1387,17 +1305,6 @@ "algorithm.model(x.to(algorithm.device))" ] }, - { - "cell_type": "code", - "execution_count": 131, - "metadata": {}, - "outputs": [], - "source": [ - "for batch in datasets['train']['loader']:\n", - " x, y_true, metadata = batch\n", - " break" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -1580,13 +1487,13 @@ }, { "cell_type": "code", - "execution_count": 124, + "execution_count": 100, "metadata": {}, "outputs": [], "source": [ "def double_conv(in_channels, out_channels): \n", " return nn.Sequential(\n", - " nn.Conv1d(in_channels, out_channels, 7, padding=3), \n", + " nn.Conv1d(in_channels, out_channels, 7, padding=2), \n", " nn.BatchNorm1d(out_channels), \n", " nn.ReLU(inplace=True),\n", " nn.Conv1d(out_channels, out_channels, 7, padding=3), \n", @@ -1665,110 +1572,16 @@ }, { "cell_type": "code", - "execution_count": 125, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true - } - }, + "execution_count": 101, + "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[('dconv_down1.0.weight', 630),\n", - " ('dconv_down1.0.bias', 15),\n", - " ('dconv_down1.1.weight', 15),\n", - " ('dconv_down1.1.bias', 15),\n", - " ('dconv_down1.3.weight', 1575),\n", - " ('dconv_down1.3.bias', 15),\n", - " ('dconv_down1.4.weight', 15),\n", - " ('dconv_down1.4.bias', 15),\n", - " ('dconv_down2.0.weight', 2310),\n", - " ('dconv_down2.0.bias', 22),\n", - " ('dconv_down2.1.weight', 22),\n", - " ('dconv_down2.1.bias', 22),\n", - " ('dconv_down2.3.weight', 3388),\n", - " ('dconv_down2.3.bias', 22),\n", - " ('dconv_down2.4.weight', 22),\n", - " ('dconv_down2.4.bias', 22),\n", - " ('dconv_down3.0.weight', 5082),\n", - " ('dconv_down3.0.bias', 33),\n", - " ('dconv_down3.1.weight', 33),\n", - " ('dconv_down3.1.bias', 33),\n", - " ('dconv_down3.3.weight', 7623),\n", - " ('dconv_down3.3.bias', 33),\n", - " ('dconv_down3.4.weight', 33),\n", - " ('dconv_down3.4.bias', 33),\n", - " ('dconv_down4.0.weight', 11319),\n", - " ('dconv_down4.0.bias', 49),\n", - " ('dconv_down4.1.weight', 49),\n", - " ('dconv_down4.1.bias', 49),\n", - " ('dconv_down4.3.weight', 16807),\n", - " ('dconv_down4.3.bias', 49),\n", - " ('dconv_down4.4.weight', 49),\n", - " ('dconv_down4.4.bias', 49),\n", - " ('dconv_down5.0.weight', 25039),\n", - " ('dconv_down5.0.bias', 73),\n", - " ('dconv_down5.1.weight', 73),\n", - " ('dconv_down5.1.bias', 73),\n", - " ('dconv_down5.3.weight', 37303),\n", - " ('dconv_down5.3.bias', 73),\n", - " ('dconv_down5.4.weight', 73),\n", - " ('dconv_down5.4.bias', 73),\n", - " ('dconv_down6.0.weight', 55699),\n", - " ('dconv_down6.0.bias', 109),\n", - " ('dconv_down6.1.weight', 109),\n", - " ('dconv_down6.1.bias', 109),\n", - " ('dconv_down6.3.weight', 83167),\n", - " ('dconv_down6.3.bias', 109),\n", - " ('dconv_down6.4.weight', 109),\n", - " ('dconv_down6.4.bias', 109),\n", - " ('dconv_up5.0.weight', 93002),\n", - " ('dconv_up5.0.bias', 73),\n", - " ('dconv_up5.1.weight', 73),\n", - " ('dconv_up5.1.bias', 73),\n", - " ('dconv_up5.3.weight', 37303),\n", - " ('dconv_up5.3.bias', 73),\n", - " ('dconv_up5.4.weight', 73),\n", - " ('dconv_up5.4.bias', 73),\n", - " ('dconv_up4.0.weight', 41846),\n", - " ('dconv_up4.0.bias', 49),\n", - " ('dconv_up4.1.weight', 49),\n", - " ('dconv_up4.1.bias', 49),\n", - " ('dconv_up4.3.weight', 16807),\n", - " ('dconv_up4.3.bias', 49),\n", - " ('dconv_up4.4.weight', 49),\n", - " ('dconv_up4.4.bias', 49),\n", - " ('dconv_up3.0.weight', 18942),\n", - " ('dconv_up3.0.bias', 33),\n", - " ('dconv_up3.1.weight', 33),\n", - " ('dconv_up3.1.bias', 33),\n", - " ('dconv_up3.3.weight', 7623),\n", - " ('dconv_up3.3.bias', 33),\n", - " ('dconv_up3.4.weight', 33),\n", - " ('dconv_up3.4.bias', 33),\n", - " ('dconv_up2.0.weight', 8470),\n", - " ('dconv_up2.0.bias', 22),\n", - " ('dconv_up2.1.weight', 22),\n", - " ('dconv_up2.1.bias', 22),\n", - " ('dconv_up2.3.weight', 3388),\n", - " ('dconv_up2.3.bias', 22),\n", - " ('dconv_up2.4.weight', 22),\n", - " ('dconv_up2.4.bias', 22),\n", - " ('dconv_up1.0.weight', 3885),\n", - " ('dconv_up1.0.bias', 15),\n", - " ('dconv_up1.1.weight', 15),\n", - " ('dconv_up1.1.bias', 15),\n", - " ('dconv_up1.3.weight', 1575),\n", - " ('dconv_up1.3.bias', 15),\n", - " ('dconv_up1.4.weight', 15),\n", - " ('dconv_up1.4.bias', 15),\n", - " ('conv_last.weight', 30),\n", - " ('conv_last.bias', 2)]" + "485773" ] }, - "execution_count": 125, + "execution_count": 101, "metadata": {}, "output_type": "execute_result" } @@ -1779,58 +1592,141 @@ "\n", "lst = [(x[0], x[1].numel()) for x in model.named_parameters()]\n", "#np.sum([x[1] for x in lst])\n", - "count_parameters(model)\n", - "lst" + "count_parameters(model)" ] }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 102, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[('conv1.weight', 4750),\n", - " ('conv1.bias', 50),\n", - " ('conv2.weight', 27500),\n", - " ('conv2.bias', 50),\n", - " ('conv3.weight', 17500),\n", - " ('conv3.bias', 50),\n", - " ('bn1.weight', 50),\n", - " ('bn1.bias', 50),\n", - " ('bn2.weight', 50),\n", - " ('bn2.bias', 50),\n", - " ('bn3.weight', 50),\n", - " ('bn3.bias', 50),\n", - " ('fc1.weight', 4200000),\n", - " ('fc1.bias', 1000),\n", - " ('bn4.weight', 1000),\n", - " ('bn4.bias', 1000),\n", - " ('fc2.weight', 1000000),\n", - " ('fc2.bias', 1000),\n", - " ('bn5.weight', 1000),\n", - " ('bn5.bias', 1000),\n", - " ('fc3.weight', 1000),\n", - " ('fc3.bias', 1)]" + "UNet(\n", + " (dconv_down1): Sequential(\n", + " (0): Conv1d(6, 15, kernel_size=(7,), stride=(1,), padding=(2,))\n", + " (1): BatchNorm1d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv1d(15, 15, kernel_size=(7,), stride=(1,), padding=(3,))\n", + " (4): BatchNorm1d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " (dconv_down2): Sequential(\n", + " (0): Conv1d(15, 22, kernel_size=(7,), stride=(1,), padding=(2,))\n", + " (1): BatchNorm1d(22, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv1d(22, 22, kernel_size=(7,), stride=(1,), padding=(3,))\n", + " (4): BatchNorm1d(22, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " (dconv_down3): Sequential(\n", + " (0): Conv1d(22, 33, kernel_size=(7,), stride=(1,), padding=(2,))\n", + " (1): BatchNorm1d(33, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv1d(33, 33, kernel_size=(7,), stride=(1,), padding=(3,))\n", + " (4): BatchNorm1d(33, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " (dconv_down4): Sequential(\n", + " (0): Conv1d(33, 49, kernel_size=(7,), stride=(1,), padding=(2,))\n", + " (1): BatchNorm1d(49, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv1d(49, 49, kernel_size=(7,), stride=(1,), padding=(3,))\n", + " (4): BatchNorm1d(49, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " (dconv_down5): Sequential(\n", + " (0): Conv1d(49, 73, kernel_size=(7,), stride=(1,), padding=(2,))\n", + " (1): BatchNorm1d(73, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv1d(73, 73, kernel_size=(7,), stride=(1,), padding=(3,))\n", + " (4): BatchNorm1d(73, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " (dconv_down6): Sequential(\n", + " (0): Conv1d(73, 109, kernel_size=(7,), stride=(1,), padding=(2,))\n", + " (1): BatchNorm1d(109, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv1d(109, 109, kernel_size=(7,), stride=(1,), padding=(3,))\n", + " (4): BatchNorm1d(109, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " (maxpool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + " (upsample): Upsample(scale_factor=2.0, mode=bilinear)\n", + " (dconv_up5): Sequential(\n", + " (0): Conv1d(182, 73, kernel_size=(7,), stride=(1,), padding=(2,))\n", + " (1): BatchNorm1d(73, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv1d(73, 73, kernel_size=(7,), stride=(1,), padding=(3,))\n", + " (4): BatchNorm1d(73, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " (dconv_up4): Sequential(\n", + " (0): Conv1d(122, 49, kernel_size=(7,), stride=(1,), padding=(2,))\n", + " (1): BatchNorm1d(49, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv1d(49, 49, kernel_size=(7,), stride=(1,), padding=(3,))\n", + " (4): BatchNorm1d(49, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " (dconv_up3): Sequential(\n", + " (0): Conv1d(82, 33, kernel_size=(7,), stride=(1,), padding=(2,))\n", + " (1): BatchNorm1d(33, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv1d(33, 33, kernel_size=(7,), stride=(1,), padding=(3,))\n", + " (4): BatchNorm1d(33, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " (dconv_up2): Sequential(\n", + " (0): Conv1d(55, 22, kernel_size=(7,), stride=(1,), padding=(2,))\n", + " (1): BatchNorm1d(22, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv1d(22, 22, kernel_size=(7,), stride=(1,), padding=(3,))\n", + " (4): BatchNorm1d(22, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " (dconv_up1): Sequential(\n", + " (0): Conv1d(37, 15, kernel_size=(7,), stride=(1,), padding=(2,))\n", + " (1): BatchNorm1d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (2): ReLU(inplace=True)\n", + " (3): Conv1d(15, 15, kernel_size=(7,), stride=(1,), padding=(3,))\n", + " (4): BatchNorm1d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (5): ReLU(inplace=True)\n", + " )\n", + " (conv_last): Conv2d(15, 2, kernel_size=(1, 1), stride=(1, 1))\n", + ")" ] }, - "execution_count": 34, + "execution_count": 102, "metadata": {}, "output_type": "execute_result" } ], + "source": [ + "model" + ] + }, + { + "cell_type": "code", + "execution_count": 96, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'Beagle' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mp\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequires_grad\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mBeagle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;31m#model = DanQ(50, 5)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name 'Beagle' is not defined" + ] + } + ], "source": [ "def count_parameters(model):\n", - " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", - "\n", - "model = Beagle()\n", - "#model = DanQ(50, 5)\n", - "\n", - "lst = [(x[0], x[1].numel()) for x in model.named_parameters()]\n", - "#np.sum([x[1] for x in lst])\n", - "count_parameters(model)\n", - "lst" + " return sum(p.numel() for p in model.parameters() if p.requires_grad)" ] }, { From 42794b06921e683f72c98442c8625b05bacb8827 Mon Sep 17 00:00:00 2001 From: aikanor Date: Mon, 1 Mar 2021 08:08:05 -0800 Subject: [PATCH 072/244] check in training/eval --- examples/models/CNN_genome.py | 18 +- examples/sbox_run_expt.ipynb | 1733 +++++++++++++++++++------- wilds/datasets/encodetfbs_dataset.py | 77 +- 3 files changed, 1344 insertions(+), 484 deletions(-) diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py index f1b90d07..147f8c9e 100644 --- a/examples/models/CNN_genome.py +++ b/examples/models/CNN_genome.py @@ -6,7 +6,14 @@ -def double_conv(in_channels, out_channels): +def single_conv(in_channels, out_channels): + return nn.Sequential( + nn.Conv1d(in_channels, out_channels, 7, padding=3), + nn.BatchNorm1d(out_channels), + nn.ReLU(inplace=True) + ) + +def double_conv(in_channels, out_channels): return nn.Sequential( nn.Conv1d(in_channels, out_channels, 7, padding=3), nn.BatchNorm1d(out_channels), @@ -19,10 +26,10 @@ def double_conv(in_channels, out_channels): class UNet(nn.Module): - def __init__(self, n_class): + def __init__(self, n_class, n_channels_in=6): super().__init__() - self.dconv_down1 = double_conv(6, 15) + self.dconv_down1 = double_conv(n_channels_in, 15) self.dconv_down2 = double_conv(15, 22) self.dconv_down3 = double_conv(22, 33) self.dconv_down4 = double_conv(33, 49) @@ -30,7 +37,8 @@ def __init__(self, n_class): self.dconv_down6 = double_conv(73, 109) self.maxpool = nn.MaxPool1d(2) - self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + self.conv_middle = single_conv(109, 109) self.dconv_up5 = double_conv(73 + 109, 73) self.dconv_up4 = double_conv(49 + 73, 49) @@ -60,6 +68,8 @@ def forward(self, x): x = self.maxpool(conv5) # (input_size / 32) x 73 conv6 = self.dconv_down6(x) # (input_size / 32) x 109 + # conv6 = self.conv_middle(conv6) # Optional: convolution here. + # Encoder finished. x = self.upsample(conv6) # (input_size / 16) x 109 diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb index 66712a29..06440dc6 100644 --- a/examples/sbox_run_expt.ipynb +++ b/examples/sbox_run_expt.ipynb @@ -34,7 +34,33 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "163 µs ± 343 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" + ] + } + ], + "source": [ + "import pyBigWig\n", + "# %timeit bw = pyBigWig.open(\"/users/abalsubr/wilds/examples/data/encode-tfbs_v1.0/DNASE.K562.fc.signal.bigwig\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:The OGB package is out of date. Your version is 1.2.4, while the latest version is 1.2.5.\n" + ] + } + ], "source": [ "import os, csv\n", "import time\n", @@ -59,16 +85,16 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" + "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" ] }, - "execution_count": 2, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -76,7 +102,7 @@ "source": [ "''' set default hyperparams in default_hyperparams.py '''\n", "parser = argparse.ArgumentParser()\n", - "\n", + "CombinatorialGrouper\n", "# Required arguments\n", "parser.add_argument('-d', '--dataset', choices=supported.datasets, required=True)\n", "parser.add_argument('--algorithm', required=True, choices=supported.algorithms)\n", @@ -163,7 +189,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -175,27 +201,20 @@ "config_encode = parser.parse_args(argstr_encode.split())\n", "config_encode = populate_defaults(config_encode)\n", "\n", - "config = config_camelyon\n", - "# config = config_encode" + "# config = config_camelyon\n", + "config = config_encode\n" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:root:The OGB package is out of date. Your version is 1.2.4, while the latest version is 1.2.5.\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - "Dataset: camelyon17\n", + "Dataset: encode-tfbs\n", "Algorithm: ERM\n", "Root dir: data\n", "Split scheme: official\n", @@ -207,26 +226,26 @@ "Uniform over groups: False\n", "Distinct groups: None\n", "N groups per batch: 2\n", - "Batch size: 32\n", + "Batch size: 64\n", "Eval loader: standard\n", - "Model: densenet121\n", + "Model: leopard\n", "Model kwargs: {'pretrained': False}\n", - "Train transform: image_base\n", - "Eval transform: image_base\n", - "Target resolution: (224, 224)\n", + "Train transform: None\n", + "Eval transform: None\n", + "Target resolution: None\n", "Resize scale: None\n", "Max token length: None\n", "Loss function: cross_entropy\n", - "Groupby fields: ['hospital']\n", + "Groupby fields: ['celltype', 'y']\n", "Group dro step size: None\n", - "Coral penalty weight: 0.1\n", - "Irm lambda: 1.0\n", + "Coral penalty weight: None\n", + "Irm lambda: None\n", "Irm penalty anneal iters: None\n", "Algo log metric: accuracy\n", "Val metric: acc_avg\n", "Val metric decreasing: False\n", "N epochs: 5\n", - "Optimizer: SGD\n", + "Optimizer: Adam\n", "Lr: 0.001\n", "Weight decay: 0.01\n", "Max grad norm: None\n", @@ -250,7 +269,42 @@ "Use wandb: False\n", "Progress bar: False\n", "Resume: False\n", - "\n" + "\n", + "chr2 3.6633927822113037\n", + "chr3 6.7115819454193115\n", + "chr4 9.648637771606445\n", + "chr5 12.439441919326782\n", + "chr6 15.091757774353027\n", + "chr7 17.542895555496216\n", + "chr9 19.707583904266357\n", + "chr10 21.79905652999878\n", + "chr11 23.86957049369812\n", + "chr12 25.918642044067383\n", + "chr13 27.675577402114868\n", + "chr14 29.3148353099823\n", + "chr15 30.881144046783447\n", + "chr16 32.271193504333496\n", + "chr17 33.51785063743591\n", + "chr18 34.72123050689697\n", + "chr19 35.627156257629395\n", + "chr20 36.59872794151306\n", + "chr22 37.37847852706909\n", + "chrX 39.77280807495117\n", + "chr1 43.60475468635559\n", + "chr8 45.86070203781128\n", + "chr21 46.59553360939026\n" + ] + }, + { + "ename": "NameError", + "evalue": "name '_all_celltypes' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;31m# Data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m full_dataset = supported.datasets[config.dataset](\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0mroot_dir\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mroot_dir\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mdownload\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdownload\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/wilds/datasets/encodetfbs_dataset.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, root_dir, download, split_scheme)\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0;31m# Get the y values, and remove ambiguous labels by default.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 102\u001b[0m \u001b[0mpd_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 103\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mct\u001b[0m \u001b[0;32min\u001b[0m \u001b[0m_all_celltypes\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 104\u001b[0m \u001b[0mtc_chr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mall_df\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'start'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'stop'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mct\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[0mtc_chr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcolumns\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'start'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'stop'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'y'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name '_all_celltypes' is not defined" ] } ], @@ -300,15 +354,47 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "import copy\n", "full_dataset_camelyon17 = copy.deepcopy(full_dataset)\n", "\n", "# supported.datasets[config_encode.dataset]\n", - "# print(config_camelyon.train_transform, config_encode.train_transform)" + "# print(config_camelyon.train_transform, config_encode.train_transform)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'full_dataset' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mfull_dataset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mNameError\u001b[0m: name 'full_dataset' is not defined" + ] + } + ], + "source": [ + "full_dataset" ] }, { @@ -320,26 +406,26 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "chr2 3.764267683029175\n", - "chr9 5.914910078048706\n", - "chr11 7.964999675750732\n", - "chr1 11.748822927474976\n", - "chr8 14.01279878616333\n", - "chr21 14.737261772155762\n", - "H1-hESC 14.73790693283081\n", - "HCT116 14.737961292266846\n", - "HeLa-S3 14.737993240356445\n", - "HepG2 14.738024950027466\n", - "K562 14.73805570602417\n", - "A549 14.738086223602295\n", - "GM12878 14.738116979598999\n" + "chr2 3.90254282951355\n", + "chr9 6.149690628051758\n", + "chr11 8.327073097229004\n", + "chr1 12.291624546051025\n", + "chr8 14.624409675598145\n", + "chr21 15.413429021835327\n", + "H1-hESC 15.415196895599365\n", + "HCT116 15.415941953659058\n", + "HeLa-S3 15.416455030441284\n", + "HepG2 15.417592763900757\n", + "K562 15.418397426605225\n", + "A549 15.41891360282898\n", + "GM12878 15.419732332229614\n" ] } ], @@ -405,28 +491,32 @@ " print(chrom, time.time() - itime)\n", "\n", "_dnase_allcelltypes = {}\n", + "ct = 'avg'\n", + "dnase_avg_bw_path = os.path.join(_data_dir, 'Leopard_dnase/{}.bigwig'.format(ct))\n", + "_dnase_allcelltypes[ct] = pyBigWig.open(dnase_avg_bw_path)\n", "for ct in _all_celltypes:\n", " \"\"\"\n", - " dnase_filename = os.path.join(_data_dir, '{}_dnase.npz'.format(ct))\n", + " dnase_filename = os.path.join(self._data_dir, '{}_dnase.npz'.format(ct))\n", " dnase_npz_contents = np.load(dnase_filename)\n", - " _dnase_allcelltypes[ct] = {}\n", - " for chrom in _all_chroms: #_seq_bp:\n", - " _dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom]\n", + " self._dnase_allcelltypes[ct] = {}\n", + " for chrom in self._all_chroms: #self._seq_bp:\n", + " self._dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom]\n", " \"\"\"\n", - " _dnase_allcelltypes[ct] = 'DNASE.{}.fc.signal.bigwig'\n", + " dnase_bw_path = os.path.join(_data_dir, 'Leopard_dnase/{}.bigwig'.format(ct))\n", + " _dnase_allcelltypes[ct] = pyBigWig.open(dnase_bw_path)\n", " print(ct, time.time() - itime)" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "66.32568740844727\n" + "74.06488299369812\n" ] } ], @@ -445,33 +535,33 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - ":12: SettingWithCopyWarning: \n", + ":8: SettingWithCopyWarning: \n", "A value is trying to be set on a copy of a slice from a DataFrame.\n", "Try using .loc[row_indexer,col_indexer] = value instead\n", "\n", "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", - " tc_chr['y'] = y_array\n" + " tc_chr['y'] = tc_chr['y'].replace({'U': 0, 'B': 1, 'A': 0.5}).values\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "11.363114833831787\n", - "21.872379302978516\n", - "32.51760506629944\n", - "42.88175559043884\n", - "53.35902285575867\n", - "63.94557332992554\n", - "74.44822382926941\n", - "92.237633228302\n" + "13.203907012939453\n", + "22.746795415878296\n", + "32.076903104782104\n", + "41.43525505065918\n", + "51.37267017364502\n", + "61.07364773750305\n", + "71.10506510734558\n", + "100.72125267982483\n" ] } ], @@ -483,12 +573,12 @@ "for ct in _all_celltypes:\n", " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", - " y_array = tc_chr['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", + " tc_chr['y'] = tc_chr['y'].replace({'U': 0, 'B': 1, 'A': 0.5}).values\n", " \n", - " # Now filter out ambiguous labels\n", - " non_ambig_mask = (y_array != -1)\n", - " tc_chr['y'] = y_array\n", - " tc_chr = tc_chr[non_ambig_mask]\n", + " # # Now filter out ambiguous labels\n", + " # non_ambig_mask = (y_array != -1)\n", + " # tc_chr['y'] = y_array\n", + " # tc_chr = tc_chr[non_ambig_mask]\n", " \n", " tc_chr.insert(len(tc_chr.columns), 'celltype', ct)\n", " pd_list.append(tc_chr)\n", @@ -497,101 +587,440 @@ "\n", "print(time.time() - itime)\n", "\n", - "# y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", - "# non_ambig_mask = (y_array != -1)\n", - "# metadata_df['y'] = y_array\n", - "# _metadata_df = metadata_df[non_ambig_mask]\n", + "_metadata_df = metadata_df\n" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "H1-hESC 8.80908489227295\n", + "H1-hESC 12.474784135818481\n", + "H1-hESC 15.258162498474121\n", + "HCT116 23.023517370224\n", + "HCT116 25.439095735549927\n", + "HCT116 27.21790099143982\n", + "HeLa-S3 34.81065845489502\n", + "HeLa-S3 35.2776780128479\n", + "HeLa-S3 36.60090255737305\n", + "HepG2 44.36072087287903\n", + "HepG2 44.74501991271973\n", + "HepG2 46.02569603919983\n", + "K562 53.825233697891235\n", + "K562 54.182188749313354\n", + "K562 55.44522547721863\n", + "A549 62.980581283569336\n", + "A549 63.34522008895874\n", + "A549 64.59721446037292\n", + "GM12878 72.41460752487183\n", + "GM12878 72.7955391407013\n", + "GM12878 74.05369997024536\n" + ] + } + ], + "source": [ + "# np.unique(_metadata_df['y'])\n", "\n", - "# print(time.time() - itime)" + "# Downsample negatives to balance each celltype\n", + "samp_ndces = []\n", + "itime = time.time()\n", + "neg_msk = (_metadata_df['y'] == 0)\n", + "pos_msk = (_metadata_df['y'] != 0)\n", + "for ct in _all_celltypes:\n", + " celltype_msk = (_metadata_df['celltype'] == ct)\n", + " print(ct, time.time() - itime)\n", + " neg_ct_msk = np.logical_and(celltype_msk, neg_msk)\n", + " pos_ct_msk = np.logical_and(celltype_msk, pos_msk)\n", + " print(ct, time.time() - itime)\n", + " neg_ndces = np.where(neg_ct_msk)[0]\n", + " pos_ndces = np.where(pos_ct_msk)[0]\n", + " np.random.seed(42)\n", + " samp_neg_ndces = np.random.choice(neg_ndces, size=len(pos_ndces), replace=False)\n", + " samp_ndces.extend(samp_neg_ndces)\n", + " samp_ndces.extend(pos_ndces)\n", + " print(ct, time.time() - itime)\n", + "_metadata_df = _metadata_df.iloc[samp_ndces, :]" ] }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 145, "metadata": {}, "outputs": [], "source": [ - "# window_size = 12800\n", - "# window_interval = window_size/2\n", - "# trl_mask = (train_regions_labeled['start']%window_interval == 0)\n", - "# train_regions_labeled[trl_mask]" + "def get_random_label_vec(metadata_df, output_size=128):\n", + " # Sample a positively labeled region at random\n", + " pos_mdf = metadata_df[metadata_df['y'] == 1] #.iloc[ metadata_df['chr'] == s['chr'], : ]\n", + " pos_seed_region = pos_mdf.iloc[np.random.randint(pos_mdf.shape[0])]\n", + "\n", + " # Extract regions from this chromosome in this celltype, to get a window of labels from\n", + " chr_msk = np.array(metadata_df['chr']) == pos_seed_region['chr']\n", + " ct_msk = np.array(metadata_df['celltype']) == pos_seed_region['celltype']\n", + " mdf = metadata_df[chr_msk & ct_msk]\n", + "\n", + " # Get labels\n", + " start_ndx = np.where(mdf['start'] == pos_seed_region['start'])[0][0]\n", + " y_label_vec = mdf.iloc[start_ndx:start_ndx+output_size, :]['y']" + ] + }, + { + "cell_type": "code", + "execution_count": 146, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
chrstartstopycelltype
5937924chr86008000.0HepG2
5937925chr86508500.0HepG2
5937926chr87009000.0HepG2
5937927chr87509500.0HepG2
5937928chr880010000.0HepG2
..................
8843006chr81463632001463634000.0HepG2
8843007chr81463632501463634500.0HepG2
8843008chr81463633001463635000.0HepG2
8843009chr81463633501463635500.0HepG2
8843010chr81463634001463636000.0HepG2
\n", + "

2905087 rows × 5 columns

\n", + "
" + ], + "text/plain": [ + " chr start stop y celltype\n", + "5937924 chr8 600 800 0.0 HepG2\n", + "5937925 chr8 650 850 0.0 HepG2\n", + "5937926 chr8 700 900 0.0 HepG2\n", + "5937927 chr8 750 950 0.0 HepG2\n", + "5937928 chr8 800 1000 0.0 HepG2\n", + "... ... ... ... ... ...\n", + "8843006 chr8 146363200 146363400 0.0 HepG2\n", + "8843007 chr8 146363250 146363450 0.0 HepG2\n", + "8843008 chr8 146363300 146363500 0.0 HepG2\n", + "8843009 chr8 146363350 146363550 0.0 HepG2\n", + "8843010 chr8 146363400 146363600 0.0 HepG2\n", + "\n", + "[2905087 rows x 5 columns]" + ] + }, + "execution_count": 146, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 154, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "8809571 1.0\n", + "8809572 1.0\n", + "8809573 1.0\n", + "8809574 1.0\n", + "8809575 1.0\n", + "8809576 0.5\n", + "8809577 0.5\n", + "8809578 0.5\n", + "8809579 1.0\n", + "8809580 1.0\n", + "8809581 1.0\n", + "8809582 1.0\n", + "8809583 1.0\n", + "8809584 1.0\n", + "8809585 0.5\n", + "8809586 0.5\n", + "8809587 0.0\n", + "8809588 0.0\n", + "8809589 0.0\n", + "8809590 0.0\n", + "8809591 0.0\n", + "8809592 0.0\n", + "8809593 0.0\n", + "8809594 0.0\n", + "8809595 0.0\n", + "8809596 0.0\n", + "8809597 0.0\n", + "8809598 0.0\n", + "8809599 0.0\n", + "8809600 0.0\n", + "8809601 0.0\n", + "8809602 0.0\n", + "Name: y, dtype: float64" + ] + }, + "execution_count": 154, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 107, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0.])" + ] + }, + "execution_count": 107, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.unique(mdf[:256+start_bin]['y'])" + ] + }, + { + "cell_type": "code", + "execution_count": 150, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "chr chr8\n", + "start 144691450\n", + "stop 144691650\n", + "y 1.0\n", + "celltype HepG2\n", + "Name: 8809571, dtype: object" + ] + }, + "execution_count": 150, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pos_seed_region" ] }, { "cell_type": "code", - "execution_count": 108, + "execution_count": 98, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "686900" + "array([ 600, 600, 600, ..., 135005900, 135005900,\n", + " 135005900])" ] }, - "execution_count": 108, + "execution_count": 98, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "(metadata_df['y'] == 1).sum()\n", - "# pd_list[0][non_ambig_mask]" + "# arr = metadata_df[mdf_msk]['start']\n", + "#arr == \n", + "np.sort(arr)" ] }, { "cell_type": "code", - "execution_count": 88, + "execution_count": 69, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "H1-hESC 8.10781979560852\n", - "H1-hESC 8.47616195678711\n", - "H1-hESC 9.822284698486328\n", - "HCT116 17.048683881759644\n", - "HCT116 17.41142964363098\n", - "HCT116 18.752415657043457\n", - "HeLa-S3 26.464386463165283\n", - "HeLa-S3 26.860748291015625\n", - "HeLa-S3 28.151614665985107\n", - "HepG2 35.439460039138794\n", - "HepG2 35.83507966995239\n", - "HepG2 37.079824924468994\n", - "K562 44.71583318710327\n", - "K562 45.092923164367676\n", - "K562 46.389798402786255\n", - "A549 53.895429372787476\n", - "A549 54.27841639518738\n", - "A549 55.64506816864014\n", - "GM12878 63.17967939376831\n", - "GM12878 63.545384883880615\n", - "GM12878 64.84915113449097\n" + "116.39193439483643\n" ] } ], "source": [ - "# Downsample negatives to balance each celltype\n", - "samp_ndces = []\n", "itime = time.time()\n", - "neg_msk = (_metadata_df['y'] == 0)\n", - "pos_msk = (_metadata_df['y'] == 1)\n", - "for ct in _all_celltypes:\n", - " celltype_msk = (_metadata_df['celltype'] == ct)\n", - " print(ct, time.time() - itime)\n", - " neg_ct_msk = np.logical_and(celltype_msk, neg_msk)\n", - " pos_ct_msk = np.logical_and(celltype_msk, pos_msk)\n", - " print(ct, time.time() - itime)\n", - " neg_ndces = np.where(neg_ct_msk)[0]\n", - " pos_ndces = np.where(pos_ct_msk)[0]\n", - " np.random.seed(42)\n", - " samp_neg_ndces = np.random.choice(neg_ndces, size=len(pos_ndces), replace=False)\n", - " samp_ndces.extend(samp_neg_ndces)\n", - " samp_ndces.extend(pos_ndces)\n", - " print(ct, time.time() - itime)\n", - "_metadata_df = _metadata_df.iloc[samp_ndces, :]\n", - "\n", + "lts = ['{}:{}-{}'.format(x[0], x[1], x[2]) for x in zip(metadata_df['chr'], metadata_df['start'], metadata_df['stop'])]\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "202800" + ] + }, + "execution_count": 55, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "#_metadata_df['y']\n", + "# s = metadata_df.iloc[np.array(pos_msk), :]\n", + "ntry = s.iloc[5]\n", + "ntry['start'] + 12800\n", + "# s['chr'], s['start'], s['stop'] # np.unique(s['chr'], return_counts=True)\n", + "# all_df\n", + "# metadata_df" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "H1-hESC 7.871182918548584\n", + "H1-hESC 8.298529148101807\n", + "H1-hESC 9.57175898551941\n", + "HCT116 17.01794719696045\n", + "HCT116 17.36267113685608\n", + "HCT116 18.669682025909424\n", + "HeLa-S3 26.405478954315186\n", + "HeLa-S3 26.759119272232056\n", + "HeLa-S3 28.043395042419434\n", + "HepG2 35.623862981796265\n", + "HepG2 35.98245143890381\n", + "HepG2 37.29869079589844\n", + "K562 44.92080807685852\n", + "K562 45.256179332733154\n", + "K562 46.7364935874939\n", + "A549 54.39264512062073\n", + "A549 54.74424934387207\n", + "A549 56.03351712226868\n", + "GM12878 63.745240211486816\n", + "GM12878 64.1029920578003\n", + "GM12878 65.43286633491516\n" + ] + } + ], + "source": [ "train_regions_mask = np.isin(_metadata_df['chr'], _train_chroms)\n", "val_regions_mask = np.isin(_metadata_df['chr'], _test_chroms)\n", "train_celltype_mask = np.isin(_metadata_df['celltype'], _train_celltypes)\n", @@ -629,17 +1058,13 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Initialize dataset object" + "# Dataset object (long version)" ] }, { "cell_type": "code", - "execution_count": 106, - "metadata": { - "jupyter": { - "source_hidden": true - } - }, + "execution_count": 3, + "metadata": {}, "outputs": [], "source": [ "import os, time\n", @@ -719,11 +1144,14 @@ " \n", " self._dnase_allcelltypes = {}\n", " for ct in self._all_celltypes:\n", + " \"\"\"\n", " dnase_filename = os.path.join(self._data_dir, '{}_dnase.npz'.format(ct))\n", " dnase_npz_contents = np.load(dnase_filename)\n", " self._dnase_allcelltypes[ct] = {}\n", " for chrom in self._all_chroms: #self._seq_bp:\n", " self._dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom]\n", + " \"\"\"\n", + " self._dnase_allcelltypes[ct] = os.path.join(self._data_dir, 'DNASE.{}.fc.signal.bigwig'.format(ct))\n", " print(ct, time.time() - itime)\n", " \n", " # Read in metadata dataframe from training+validation data\n", @@ -733,24 +1161,22 @@ " val_df = val_regions_labeled[np.isin(val_regions_labeled['chr'], self._test_chroms)]\n", " all_df = pd.concat([training_df, val_df])\n", " \n", - " # Filter by start/stop coordinate if needed (TODO: remove for final version)\n", - " # filter_msk = all_df['start'] >= 0\n", - " # filter_msk = all_df['start']%1000 == 0\n", - " # all_df = all_df[filter_msk]\n", - " \n", + " # Get the y values, and remove ambiguous labels by default.\n", " pd_list = []\n", - " for ct in self._all_celltypes:\n", + " for ct in _all_celltypes:\n", " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", + " y_array = tc_chr['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", + "\n", + " # Now filter out ambiguous labels\n", + " non_ambig_mask = (y_array != -1)\n", + " tc_chr['y'] = y_array\n", + " tc_chr = tc_chr[non_ambig_mask]\n", + "\n", " tc_chr.insert(len(tc_chr.columns), 'celltype', ct)\n", " pd_list.append(tc_chr)\n", - " metadata_df = pd.concat(pd_list)\n", - " \n", - " # Get the y values, and remove ambiguous labels by default.\n", - " y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", - " non_ambig_mask = (y_array != -1)\n", - " metadata_df['y'] = y_array\n", - " self._metadata_df = metadata_df[non_ambig_mask]\n", + " print(time.time() - itime)\n", + " self._metadata_df = pd.concat(pd_list)\n", " \n", " # Downsample negatives to balance each celltype\n", " samp_ndces = []\n", @@ -814,20 +1240,30 @@ "\n", " def get_input(self, idx):\n", " \"\"\"\n", - " Returns x for a given idx.\n", + " Returns x for a given idx in metadata_array, which has been filtered to only take windows with the desired stride.\n", " Computes this from: \n", " (1) sequence features in self._seq_bp\n", - " (2) DNase features in self._dnase_allcelltypes\n", + " (2) DNase bigwig file paths in self._dnase_allcelltypes\n", " (3) Metadata for the index (location along the genome with 200bp window width)\n", " \"\"\"\n", + " \n", " this_metadata = self._metadata_df.iloc[idx, :]\n", + " \"\"\"\n", " flank_size = 400\n", " interval_start = this_metadata['start'] - flank_size\n", " interval_end = this_metadata['stop'] + flank_size\n", " dnase_this = self._dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end]\n", " seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end]\n", " return torch.tensor(np.column_stack([seq_this, dnase_this]))\n", - "\n", + " \"\"\"\n", + " window_size = 12800\n", + " interval_start = this_metadata['start']\n", + " interval_end = this_metadata['stop'] + window_size\n", + " seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end]\n", + " dnase_bw = self._dnase_allcelltypes[this_metadata['celltype']]\n", + " dnase_this = dnase_bw.values(chrom, interval_start, interval_end, numpy=True)\n", + " return torch.tensor(np.column_stack([seq_this, dnase_this]))\n", + " \n", " def eval(self, y_pred, y_true, metadata):\n", " return self.standard_group_eval(\n", " self._metric,\n", @@ -838,7 +1274,12 @@ { "cell_type": "code", "execution_count": 107, - "metadata": {}, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, "outputs": [ { "name": "stdout", @@ -891,59 +1332,41 @@ }, { "cell_type": "code", - "execution_count": 118, - "metadata": {}, - "outputs": [], - "source": [ - "# full_dataset = copy.deepcopy(full_dataset_encode)\n", - "full_dataset = copy.deepcopy(full_dataset_camelyon17)\n", - "# full_dataset_camelyon17.split_dict" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "metadata": {}, - "outputs": [], - "source": [ - "a = np.random.choice(1210796, size=128)\n", - "seta = [full_dataset_encode.get_input(x) for x in a]\n", - "seta[0].shape" - ] - }, - { - "cell_type": "code", - "execution_count": 111, - "metadata": {}, - "outputs": [], - "source": [ - "full_dataset.metadata_fields\n", - "config = config_camelyon\n", - "#config_encode.groupby_fields\n", - "\n", - "train_grouper = CombinatorialGrouper(\n", - " dataset=full_dataset,\n", - " groupby_fields=config.groupby_fields)" - ] - }, - { - "cell_type": "code", - "execution_count": 104, + "execution_count": 2, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 104, - "metadata": {}, - "output_type": "execute_result" + "ename": "ModuleNotFoundError", + "evalue": "No module named 'pyBigWig'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# full_dataset_encode\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mpyBigWig\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'pyBigWig'" + ] } ], "source": [ - "full_dataset" + "# full_dataset_encode\n", + "import pyBigWig" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [], + "source": [ + "a = np.random.choice(1210796, size=128)\n", + "seta = [full_dataset_encode.get_input(x) for x in a]\n", + "seta[0].shape\n", + "\n", + "# full_dataset = copy.deepcopy(full_dataset_encode)\n", + "# full_dataset = copy.deepcopy(full_dataset_camelyon17)\n", + "# full_dataset_camelyon17.split_dict\n", + "\n", + "# full_dataset" ] }, { @@ -955,7 +1378,7 @@ }, { "cell_type": "code", - "execution_count": 113, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -988,27 +1411,16 @@ " hospital = 4: n = 0\n", "Dout: 2\n" ] - }, - { - "ename": "RuntimeError", - "evalue": "CUDA error: out of memory", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 61\u001b[0m \u001b[0;31m## Initialize algorithm\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 62\u001b[0;31m algorithm = initialize_algorithm(\n\u001b[0m\u001b[1;32m 63\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdatasets\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/examples/algorithms/initializer.py\u001b[0m in \u001b[0;36minitialize_algorithm\u001b[0;34m(config, datasets, train_grouper)\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;34m'ERM'\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 32\u001b[0;31m algorithm = ERM(\n\u001b[0m\u001b[1;32m 33\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0md_out\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0md_out\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/examples/algorithms/ERM.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, config, d_out, grouper, loss, metric, n_train_steps)\u001b[0m\n\u001b[1;32m 6\u001b[0m def __init__(self, config, d_out, grouper, loss,\n\u001b[1;32m 7\u001b[0m metric, n_train_steps):\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minitialize_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0md_out\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0;31m# initialize module\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m super().__init__(\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mto\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 610\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 611\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 612\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 613\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 614\u001b[0m def register_backward_hook(\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m 357\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 358\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 359\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 360\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m 357\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 358\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mchildren\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 359\u001b[0;31m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 360\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 361\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtensor_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_apply\u001b[0;34m(self, fn)\u001b[0m\n\u001b[1;32m 379\u001b[0m \u001b[0;31m# `with torch.no_grad():`\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 380\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 381\u001b[0;31m \u001b[0mparam_applied\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 382\u001b[0m \u001b[0mshould_use_set_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompute_should_use_set_data\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparam\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparam_applied\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 383\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mshould_use_set_data\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36mconvert\u001b[0;34m(t)\u001b[0m\n\u001b[1;32m 608\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mconvert_to_format\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m4\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 609\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmemory_format\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconvert_to_format\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 610\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdtype\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_floating_point\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 611\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 612\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mRuntimeError\u001b[0m: CUDA error: out of memory" - ] } ], "source": [ + "config = config_camelyon\n", + "\n", + "\n", + "train_grouper = CombinatorialGrouper(\n", + " dataset=full_dataset,\n", + " groupby_fields=config.groupby_fields)\n", + "\n", "datasets = defaultdict(dict)\n", "for split in full_dataset.split_dict.keys():\n", " if split=='train':\n", @@ -1078,188 +1490,31 @@ }, { "cell_type": "code", - "execution_count": 91, + "execution_count": 29, "metadata": {}, "outputs": [ { "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
chrstartstopycelltypesplit
3831225chr11917992501917994500H1-hESC1
4190052chr12097406002097408000H1-hESC1
7241915chr866306500663067000H1-hESC1
21449377chr238487450384876500H1-hESC0
45876013chr9569770056979000H1-hESC0
.....................
8841297chr81462777501462779501GM128782
8841298chr81462778001462780001GM128782
8841299chr81462778501462780501GM128782
8841300chr81462779001462781001GM128782
8841301chr81462779501462781501GM128782
\n", - "

1210796 rows × 6 columns

\n", - "
" - ], "text/plain": [ - " chr start stop y celltype split\n", - "3831225 chr1 191799250 191799450 0 H1-hESC 1\n", - "4190052 chr1 209740600 209740800 0 H1-hESC 1\n", - "7241915 chr8 66306500 66306700 0 H1-hESC 1\n", - "21449377 chr2 38487450 38487650 0 H1-hESC 0\n", - "45876013 chr9 5697700 5697900 0 H1-hESC 0\n", - "... ... ... ... .. ... ...\n", - "8841297 chr8 146277750 146277950 1 GM12878 2\n", - "8841298 chr8 146277800 146278000 1 GM12878 2\n", - "8841299 chr8 146277850 146278050 1 GM12878 2\n", - "8841300 chr8 146277900 146278100 1 GM12878 2\n", - "8841301 chr8 146277950 146278150 1 GM12878 2\n", - "\n", - "[1210796 rows x 6 columns]" + "" ] }, - "execution_count": 91, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# algorithm.device\n", - "_metadata_df\n", + "full_dataset\n", "# datasets['train']['loader']" ] }, { "cell_type": "code", - "execution_count": 90, + "execution_count": 15, "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'datasets' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'loader'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetadata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mNameError\u001b[0m: name 'datasets' is not defined" - ] - } - ], + "outputs": [], "source": [ "for batch in datasets['train']['loader']:\n", " x, y_true, metadata = batch\n", @@ -1268,36 +1523,70 @@ }, { "cell_type": "code", - "execution_count": 134, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,\n", + " 0, 1, 1, 1, 0, 0, 0, 0])" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" } - }, + ], + "source": [ + "y_true" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, "outputs": [ { - "ename": "RuntimeError", - "evalue": "CUDA out of memory. Tried to allocate 14.00 MiB (GPU 0; 11.93 GiB total capacity; 10.94 GiB already allocated; 5.06 MiB free; 11.32 GiB reserved in total by PyTorch)", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# datasets['train']['dataset'].size()\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torchvision/models/densenet.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 190\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 191\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 192\u001b[0;31m \u001b[0mfeatures\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 193\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minplace\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 194\u001b[0m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madaptive_avg_pool2d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 118\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 119\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torchvision/models/densenet.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, init_features)\u001b[0m\n\u001b[1;32m 109\u001b[0m \u001b[0mfeatures\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0minit_features\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 110\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlayer\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 111\u001b[0;31m \u001b[0mnew_features\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlayer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 112\u001b[0m \u001b[0mfeatures\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnew_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeatures\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torchvision/models/densenet.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 82\u001b[0m \u001b[0mbottleneck_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcall_checkpoint_bottleneck\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprev_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 84\u001b[0;31m \u001b[0mbottleneck_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbn_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprev_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 85\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0mnew_features\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbottleneck_output\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torchvision/models/densenet.py\u001b[0m in \u001b[0;36mbn_function\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0;31m# type: (List[Tensor]) -> Tensor\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0mconcated_features\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 41\u001b[0;31m \u001b[0mbottleneck_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrelu1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mconcated_features\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# noqa: T484\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 42\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mbottleneck_output\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 725\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 726\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 728\u001b[0m for hook in itertools.chain(\n\u001b[1;32m 729\u001b[0m \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m 129\u001b[0m \u001b[0mused\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mnormalization\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0me\u001b[0m\u001b[0;34m.\u001b[0m \u001b[0;32min\u001b[0m \u001b[0meval\u001b[0m \u001b[0mmode\u001b[0m \u001b[0mwhen\u001b[0m \u001b[0mbuffers\u001b[0m \u001b[0mare\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 130\u001b[0m \"\"\"\n\u001b[0;32m--> 131\u001b[0;31m return F.batch_norm(\n\u001b[0m\u001b[1;32m 132\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 133\u001b[0m \u001b[0;31m# If buffers are not to be tracked, ensure that they won't be updated\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds1/lib/python3.8/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mbatch_norm\u001b[0;34m(input, running_mean, running_var, weight, bias, training, momentum, eps)\u001b[0m\n\u001b[1;32m 2054\u001b[0m \u001b[0m_verify_batch_size\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2055\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2056\u001b[0;31m return torch.batch_norm(\n\u001b[0m\u001b[1;32m 2057\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrunning_mean\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrunning_var\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2058\u001b[0m \u001b[0mtraining\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmomentum\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0meps\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackends\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcudnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0menabled\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mRuntimeError\u001b[0m: CUDA out of memory. Tried to allocate 14.00 MiB (GPU 0; 11.93 GiB total capacity; 10.94 GiB already allocated; 5.06 MiB free; 11.32 GiB reserved in total by PyTorch)" - ] + "data": { + "text/plain": [ + "tensor([[ 0.1406, -0.0628],\n", + " [ 0.0534, 0.0359],\n", + " [-0.0174, -0.0097],\n", + " [-0.0571, -0.2381],\n", + " [ 0.1590, -0.0559],\n", + " [ 0.1254, -0.0139],\n", + " [-0.0423, 0.0439],\n", + " [ 0.1621, 0.0730],\n", + " [ 0.0554, 0.0796],\n", + " [-0.0532, 0.0667],\n", + " [-0.1927, -0.0387],\n", + " [ 0.1352, -0.0385],\n", + " [-0.1320, 0.0140],\n", + " [-0.0531, -0.1171],\n", + " [-0.0378, -0.0134],\n", + " [ 0.1047, 0.0298],\n", + " [ 0.0355, -0.0497],\n", + " [ 0.1065, -0.0218],\n", + " [-0.1883, 0.1298],\n", + " [ 0.0699, -0.0875],\n", + " [-0.1233, 0.1793],\n", + " [ 0.0151, 0.0708],\n", + " [-0.0973, -0.0033],\n", + " [ 0.1027, -0.2456],\n", + " [ 0.0433, -0.0441],\n", + " [ 0.1013, -0.1020],\n", + " [ 0.1309, -0.0051],\n", + " [ 0.0028, -0.0558],\n", + " [ 0.0635, 0.0575],\n", + " [-0.0066, 0.0666],\n", + " [-0.0076, -0.0375],\n", + " [ 0.1336, 0.0024]], device='cuda:0', grad_fn=)" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ @@ -1387,13 +1676,6 @@ "outputs": [], "source": [] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "code", "execution_count": 126, @@ -1418,14 +1700,14 @@ { "cell_type": "code", "execution_count": 33, - "metadata": {}, + "metadata": { + "jupyter": { + "source_hidden": true + } + }, "outputs": [], "source": [ - "import math\n", - "import numpy as np\n", - "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", + "\n", "\n", "class Beagle(nn.Module):\n", " \"\"\"\n", @@ -1487,10 +1769,17 @@ }, { "cell_type": "code", - "execution_count": 100, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ + "import math\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "\n", "def double_conv(in_channels, out_channels): \n", " return nn.Sequential(\n", " nn.Conv1d(in_channels, out_channels, 7, padding=2), \n", @@ -1572,33 +1861,13 @@ }, { "cell_type": "code", - "execution_count": 101, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "485773" - ] - }, - "execution_count": 101, - "metadata": {}, - "output_type": "execute_result" + "execution_count": 20, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true } - ], - "source": [ - "model = UNet(2)\n", - "#model = DanQ(50, 5)\n", - "\n", - "lst = [(x[0], x[1].numel()) for x in model.named_parameters()]\n", - "#np.sum([x[1] for x in lst])\n", - "count_parameters(model)" - ] - }, - { - "cell_type": "code", - "execution_count": 102, - "metadata": {}, + }, "outputs": [ { "data": { @@ -1698,43 +1967,593 @@ ")" ] }, - "execution_count": 102, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "model = UNet(2)\n", "model" ] }, { "cell_type": "code", - "execution_count": 96, + "execution_count": 101, "metadata": {}, "outputs": [ { - "ename": "NameError", - "evalue": "name 'Beagle' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mp\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrequires_grad\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mBeagle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;31m#model = DanQ(50, 5)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mNameError\u001b[0m: name 'Beagle' is not defined" - ] + "data": { + "text/plain": [ + "485773" + ] + }, + "execution_count": 101, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ "def count_parameters(model):\n", - " return sum(p.numel() for p in model.parameters() if p.requires_grad)" + " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", + "\n", + "lst = [(x[0], x[1].numel()) for x in model.named_parameters()]\n", + "#np.sum([x[1] for x in lst])\n", + "count_parameters(model)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "6955906" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "count_parameters(algorithm.model)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "lst = [(x[0], x[1].numel()) for x in algorithm.model.named_parameters()]" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, + "outputs": [ + { + "data": { + "text/plain": [ + "DenseNet(\n", + " (features): Sequential(\n", + " (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", + " (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu0): ReLU(inplace=True)\n", + " (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", + " (denseblock1): _DenseBlock(\n", + " (denselayer1): _DenseLayer(\n", + " (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer2): _DenseLayer(\n", + " (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer3): _DenseLayer(\n", + " (norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer4): _DenseLayer(\n", + " (norm1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer5): _DenseLayer(\n", + " (norm1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer6): _DenseLayer(\n", + " (norm1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " )\n", + " (transition1): _Transition(\n", + " (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", + " )\n", + " (denseblock2): _DenseBlock(\n", + " (denselayer1): _DenseLayer(\n", + " (norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer2): _DenseLayer(\n", + " (norm1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer3): _DenseLayer(\n", + " (norm1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer4): _DenseLayer(\n", + " (norm1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer5): _DenseLayer(\n", + " (norm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer6): _DenseLayer(\n", + " (norm1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer7): _DenseLayer(\n", + " (norm1): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer8): _DenseLayer(\n", + " (norm1): BatchNorm2d(352, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer9): _DenseLayer(\n", + " (norm1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer10): _DenseLayer(\n", + " (norm1): BatchNorm2d(416, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer11): _DenseLayer(\n", + " (norm1): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(448, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer12): _DenseLayer(\n", + " (norm1): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(480, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " )\n", + " (transition2): _Transition(\n", + " (norm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", + " )\n", + " (denseblock3): _DenseBlock(\n", + " (denselayer1): _DenseLayer(\n", + " (norm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer2): _DenseLayer(\n", + " (norm1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer3): _DenseLayer(\n", + " (norm1): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer4): _DenseLayer(\n", + " (norm1): BatchNorm2d(352, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer5): _DenseLayer(\n", + " (norm1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer6): _DenseLayer(\n", + " (norm1): BatchNorm2d(416, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer7): _DenseLayer(\n", + " (norm1): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(448, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer8): _DenseLayer(\n", + " (norm1): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(480, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer9): _DenseLayer(\n", + " (norm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer10): _DenseLayer(\n", + " (norm1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer11): _DenseLayer(\n", + " (norm1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer12): _DenseLayer(\n", + " (norm1): BatchNorm2d(608, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer13): _DenseLayer(\n", + " (norm1): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer14): _DenseLayer(\n", + " (norm1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer15): _DenseLayer(\n", + " (norm1): BatchNorm2d(704, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(704, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer16): _DenseLayer(\n", + " (norm1): BatchNorm2d(736, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(736, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer17): _DenseLayer(\n", + " (norm1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer18): _DenseLayer(\n", + " (norm1): BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(800, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer19): _DenseLayer(\n", + " (norm1): BatchNorm2d(832, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer20): _DenseLayer(\n", + " (norm1): BatchNorm2d(864, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(864, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer21): _DenseLayer(\n", + " (norm1): BatchNorm2d(896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer22): _DenseLayer(\n", + " (norm1): BatchNorm2d(928, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(928, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer23): _DenseLayer(\n", + " (norm1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(960, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer24): _DenseLayer(\n", + " (norm1): BatchNorm2d(992, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(992, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " )\n", + " (transition3): _Transition(\n", + " (norm): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", + " )\n", + " (denseblock4): _DenseBlock(\n", + " (denselayer1): _DenseLayer(\n", + " (norm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer2): _DenseLayer(\n", + " (norm1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer3): _DenseLayer(\n", + " (norm1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer4): _DenseLayer(\n", + " (norm1): BatchNorm2d(608, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer5): _DenseLayer(\n", + " (norm1): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer6): _DenseLayer(\n", + " (norm1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer7): _DenseLayer(\n", + " (norm1): BatchNorm2d(704, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(704, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer8): _DenseLayer(\n", + " (norm1): BatchNorm2d(736, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(736, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer9): _DenseLayer(\n", + " (norm1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer10): _DenseLayer(\n", + " (norm1): BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(800, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer11): _DenseLayer(\n", + " (norm1): BatchNorm2d(832, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer12): _DenseLayer(\n", + " (norm1): BatchNorm2d(864, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(864, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer13): _DenseLayer(\n", + " (norm1): BatchNorm2d(896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer14): _DenseLayer(\n", + " (norm1): BatchNorm2d(928, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(928, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer15): _DenseLayer(\n", + " (norm1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(960, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer16): _DenseLayer(\n", + " (norm1): BatchNorm2d(992, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(992, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " )\n", + " (norm5): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (classifier): Linear(in_features=1024, out_features=2, bias=True)\n", + ")" + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "algorithm.model" + ] }, { "cell_type": "code", diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 08cba281..04f5d08d 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -2,6 +2,7 @@ import torch import pandas as pd import numpy as np +import pyBigWig from wilds.datasets.wilds_dataset import WILDSDataset from wilds.common.grouper import CombinatorialGrouper from wilds.common.metrics.all_metrics import Accuracy @@ -76,6 +77,9 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): print(chrom, time.time() - itime) self._dnase_allcelltypes = {} + ct = 'avg' + dnase_avg_bw_path = os.path.join(self._data_dir, 'Leopard_dnase/{}.bigwig'.format(ct)) + self._dnase_allcelltypes[ct] = pyBigWig.open(dnase_avg_bw_path) for ct in self._all_celltypes: """ dnase_filename = os.path.join(self._data_dir, '{}_dnase.npz'.format(ct)) @@ -84,8 +88,8 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): for chrom in self._all_chroms: #self._seq_bp: self._dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom] """ - self._dnase_allcelltypes[ct] = 'DNASE.{}.fc.signal.bigwig' - print(ct, time.time() - itime) + dnase_bw_path = os.path.join(self._data_dir, 'Leopard_dnase/{}.bigwig'.format(ct)) + self._dnase_allcelltypes[ct] = pyBigWig.open(dnase_bw_path) # Read in metadata dataframe from training+validation data train_regions_labeled = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.train.labels.tsv.gz'.format(self._transcription_factor)), sep='\t') @@ -94,34 +98,36 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): val_df = val_regions_labeled[np.isin(val_regions_labeled['chr'], self._test_chroms)] all_df = pd.concat([training_df, val_df]) - # Filter by start/stop coordinate if needed (TODO: remove for final version) - """ - filter_msk = all_df['start'] >= 0 - filter_msk = all_df['start']%1000 == 0 - all_df = all_df[filter_msk] - """ - + # Get the y values, and remove ambiguous labels by default. pd_list = [] for ct in self._all_celltypes: tc_chr = all_df[['chr', 'start', 'stop', ct]] tc_chr.columns = ['chr', 'start', 'stop', 'y'] + y_array = tc_chr['y'].replace({'U': 0, 'B': 1, 'A': -1}).values + + # Now filter out ambiguous labels + non_ambig_mask = (y_array != -1) + tc_chr['y'] = y_array + tc_chr = tc_chr[non_ambig_mask] + tc_chr.insert(len(tc_chr.columns), 'celltype', ct) pd_list.append(tc_chr) - metadata_df = pd.concat(pd_list) - - # Get the y values, and remove ambiguous labels by default. - y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values - non_ambig_mask = (y_array != -1) - metadata_df['y'] = y_array - self._metadata_df = metadata_df[non_ambig_mask] + print(time.time() - itime) + self._metadata_df = pd.concat(pd_list) + # Downsample negatives to balance each celltype samp_ndces = [] itime = time.time() - for ct in self._all_celltypes: - neg_msk = np.logical_and((self._metadata_df['celltype'] == ct), (self._metadata_df['y'] == 0)) - pos_msk = np.logical_and((self._metadata_df['celltype'] == ct), (self._metadata_df['y'] == 1)) - neg_ndces = np.where(neg_msk)[0] - pos_ndces = np.where(pos_msk)[0] + neg_msk = (self._metadata_df['y'] == 0) + pos_msk = (self._metadata_df['y'] == 1) + for ct in _all_celltypes: + celltype_msk = (self._metadata_df['celltype'] == ct) + print(ct, time.time() - itime) + neg_ct_msk = np.logical_and(celltype_msk, neg_msk) + pos_ct_msk = np.logical_and(celltype_msk, pos_msk) + print(ct, time.time() - itime) + neg_ndces = np.where(neg_ct_msk)[0] + pos_ndces = np.where(pos_ct_msk)[0] np.random.seed(42) samp_neg_ndces = np.random.choice(neg_ndces, size=len(pos_ndces), replace=False) samp_ndces.extend(samp_neg_ndces) @@ -169,21 +175,46 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): super().__init__(root_dir, download, split_scheme) + def get_random_label_vec(metadata_df, output_size=128): + # Sample a positively labeled region at random + pos_mdf = metadata_df[metadata_df['y'] == 1] #.iloc[ metadata_df['chr'] == s['chr'], : ] + pos_seed_region = pos_mdf.iloc[np.random.randint(pos_mdf.shape[0])] + + # Extract regions from this chromosome in this celltype, to get a window of labels from + chr_msk = np.array(metadata_df['chr']) == pos_seed_region['chr'] + ct_msk = np.array(metadata_df['celltype']) == pos_seed_region['celltype'] + mdf = metadata_df[chr_msk & ct_msk] + + # Get labels + start_ndx = np.where(mdf['start'] == pos_seed_region['start'])[0][0] + y_label_vec = mdf.iloc[start_ndx:start_ndx+output_size, :]['y'] + def get_input(self, idx): """ - Returns x for a given idx. + Returns x for a given idx in metadata_array, which has been filtered to only take windows with the desired stride. Computes this from: (1) sequence features in self._seq_bp - (2) DNase features in self._dnase_allcelltypes + (2) DNase bigwig file handles in self._dnase_allcelltypes (3) Metadata for the index (location along the genome with 200bp window width) """ + this_metadata = self._metadata_df.iloc[idx, :] + """ flank_size = 400 interval_start = this_metadata['start'] - flank_size interval_end = this_metadata['stop'] + flank_size dnase_this = self._dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end] seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end] return torch.tensor(np.column_stack([seq_this, dnase_this])) + """ + window_size = 12800 + interval_start = this_metadata['start'] + interval_end = window_size + interval_start #this_metadata['stop'] + seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end] + dnase_bw = self._dnase_allcelltypes[this_metadata['celltype']] + dnase_this = dnase_bw.values(chrom, interval_start, interval_end, numpy=True) + dnase_avg = self._dnase_allcelltypes['avg'].values(chrom, interval_start, interval_end, numpy=True) + return torch.tensor(np.column_stack([seq_this, dnase_this, dnase_avg])) def eval(self, y_pred, y_true, metadata): return self.standard_group_eval( From 3d2351ee62c5c6b5f37081fb536bfc19ffde02e8 Mon Sep 17 00:00:00 2001 From: aikanor Date: Thu, 4 Mar 2021 20:12:23 -0800 Subject: [PATCH 073/244] final code (1/3) except eval, model fixes --- dataset_preprocessing/encode-tfbs/README.md | 8 +- .../encode-tfbs/prep_metadata_labels.ipynb | 382 ++++++ .../encode-tfbs/write_label_bigwig.py | 93 ++ examples/configs/datasets.py | 2 +- examples/sbox_run_expt.ipynb | 1142 +++++------------ wilds/datasets/encodetfbs_dataset.py | 103 +- 6 files changed, 836 insertions(+), 894 deletions(-) create mode 100644 dataset_preprocessing/encode-tfbs/prep_metadata_labels.ipynb create mode 100644 dataset_preprocessing/encode-tfbs/write_label_bigwig.py diff --git a/dataset_preprocessing/encode-tfbs/README.md b/dataset_preprocessing/encode-tfbs/README.md index bf3f92c6..7ecf1135 100644 --- a/dataset_preprocessing/encode-tfbs/README.md +++ b/dataset_preprocessing/encode-tfbs/README.md @@ -9,11 +9,11 @@ 2. Run `python prep_sequence.py --seq_path SEQUENCE_PATH --output_dir OUTPUT_DIR` to write the fasta file found in `SEQUENCE_PATH` to a numpy array archive in `OUTPUT_DIR`. -3. Download the accessibility data from the challenge. This consists of whole-genome DNase files in bigwig format (*.bw) from https://www.synapse.org/#!Synapse:syn6176233. +3. Download the DNase accessibility data. This consists of whole-genome DNase files in bigwig format from https://guanfiles.dcmb.med.umich.edu/Leopard/dnase_bigwig/. -4. Run `python prep_accessibility.py --input_dir INPUT_DIR --output_dir OUTPUT_DIR` to extract the bigwigs into numpy array archives, one per celltype. - -5. Download the labels from the challenge into a label directory created for this purpose: +4. Download the labels from the challenge into a label directory created for this purpose: - The training labels from https://www.synapse.org/#!Synapse:syn7413983 for the relevant transcription factor (e.g. https://www.synapse.org/#!Synapse:syn7415202 for the TF MAX). - The validation labels from https://www.synapse.org/#!Synapse:syn8441154 for the relevant transcription factor (e.g. https://www.synapse.org/#!Synapse:syn8442103 for the TF MAX). - (Optional) The validation labels for the challenge's evaluation cell type from https://www.synapse.org/#!Synapse:syn8442975 for the relevant transcription factor (generally primary liver cells, e.g. https://www.synapse.org/#!Synapse:syn8443021 for the TF MAX). + +5. Run `write_label_bigwig.py` diff --git a/dataset_preprocessing/encode-tfbs/prep_metadata_labels.ipynb b/dataset_preprocessing/encode-tfbs/prep_metadata_labels.ipynb new file mode 100644 index 00000000..9748bd25 --- /dev/null +++ b/dataset_preprocessing/encode-tfbs/prep_metadata_labels.ipynb @@ -0,0 +1,382 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import os, csv\n", + "import scipy, numpy as np, pandas as pd, time\n", + "from scipy import sparse\n", + "import pyBigWig\n", + "\n", + "# Human chromosome names\n", + "chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prep metadata df and metadata array surrounding the labels\n", + "- Metadata df contains 6400bp (window_size/2) prediction windows across the genome. Each gets a 128-bit prediction from the model.\n", + "- We store the ones that aren't fully unbound. All the rest are fully unbound." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "83.30138063430786\n", + "H1-hESC 100.73247504234314\n", + "HCT116 106.4023334980011\n", + "HeLa-S3 111.88021206855774\n", + "HepG2 117.56940197944641\n", + "K562 126.93423342704773\n", + "A549 138.21517205238342\n", + "GM12878 148.77391648292542\n", + "150.62964010238647\n", + "213.72714066505432\n" + ] + } + ], + "source": [ + "itime = time.time()\n", + "\n", + "_data_dir = '../../examples/data/encode-tfbs_v1.0/'\n", + "_transcription_factor = 'MAX'\n", + "_train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX']\n", + "_val_chroms = ['chr2', 'chr9', 'chr11']\n", + "_test_chroms = ['chr1', 'chr8', 'chr21']\n", + "_all_chroms = _train_chroms + _val_chroms + _test_chroms\n", + "_train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", + "_val_celltype = ['A549']\n", + "_test_celltype = ['GM12878']\n", + "_all_celltypes = _train_celltypes + _val_celltype + _test_celltype\n", + "\n", + "# Read in metadata dataframe from training+validation data\n", + "train_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.train.labels.tsv.gz'.format(_transcription_factor)), sep='\\t')\n", + "val_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.val.labels.tsv.gz'.format(_transcription_factor)), sep='\\t')\n", + "training_df = train_regions_labeled# [np.isin(train_regions_labeled['chr'], _train_chroms)]\n", + "val_df = val_regions_labeled# [np.isin(val_regions_labeled['chr'], _test_chroms)]\n", + "all_df = pd.concat([training_df, val_df])\n", + "\n", + "print(time.time() - itime)\n", + "\n", + "# Get the y values, and remove labels by default.\n", + "pd_list = []\n", + "for ct in _all_celltypes:\n", + " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", + " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", + " tc_chr = tc_chr[tc_chr['y'] != 'U']\n", + " tc_chr['y'] = tc_chr['y'].replace({'U': 0, 'B': 1, 'A': 0.5}).values\n", + " \n", + " tc_chr.insert(len(tc_chr.columns), 'celltype', ct)\n", + " pd_list.append(tc_chr)\n", + " print(ct, time.time() - itime)\n", + "_metadata_df = pd.concat(pd_list)\n", + "\n", + "print(time.time() - itime)\n", + "_unsorted_dir = _data_dir + 'labels/MAX/MAX_posamb.bed'\n", + "_sorted_dir = _unsorted_dir.replace('MAX_posamb', 'MAX_posamb.sorted')\n", + "_metadata_df.to_csv(\n", + " _unsorted_dir, sep='\\t', header=False, index=False\n", + ")\n", + "print(time.time() - itime)\n", + "\n", + "os.system('sort -k1,1 -k2,2n {} > {}'.format(_unsorted_dir, _sorted_dir))\n", + "\n", + "mdf_posamb = pd.read_csv(\n", + " _sorted_dir, \n", + " sep='\\t', header=None, index_col=None, names=['chr', 'start', 'stop', 'y', 'celltype']\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "H1-hESC 350.84476041793823\n", + "HCT116 358.2693498134613\n", + "HeLa-S3 364.6210968494415\n", + "HepG2 372.65956830978394\n", + "K562 380.6701240539551\n", + "A549 388.50364875793457\n", + "GM12878 394.2549338340759\n" + ] + } + ], + "source": [ + "chrom_sizes = {'chr1': 249250621, 'chr10': 135534747, 'chr11': 135006516, 'chr12': 133851895, 'chr13': 115169878, 'chr14': 107349540, 'chr15': 102531392, 'chr16': 90354753, 'chr17': 81195210, 'chr18': 78077248, 'chr19': 59128983, 'chr2': 243199373, 'chr20': 63025520, 'chr21': 48129895, 'chr22': 51304566, 'chr3': 198022430, 'chr4': 191154276, 'chr5': 180915260, 'chr6': 171115067, 'chr7': 159138663, 'chr8': 146364022, 'chr9': 141213431, 'chrX': 155270560}\n", + "chromsizes_list = [(k, v) for k, v in chrom_sizes.items()]\n", + "for ct in _all_celltypes:\n", + " ct_labels_bw_path = _data_dir + \"labels/MAX/MAX_{}.bigwig\".format(ct)\n", + " df = mdf_posamb[mdf_posamb['celltype'] == ct]\n", + " bw = pyBigWig.open(ct_labels_bw_path, \"w\")\n", + " bw.addHeader(chromsizes_list)\n", + " bw.addEntries(list(df['chr']), list(df['start']), ends=list(df['start']+50), values=list(df['y']))\n", + " print(ct, time.time() - itime)\n", + " bw.close()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + ":9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " df['window_start'] = stride*(df['start'] // stride)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "A549 63.97912120819092\n", + "GM12878 103.89278292655945\n", + "H1-hESC 182.84059262275696\n", + "HCT116 243.95744681358337\n", + "HeLa-S3 303.7187397480011\n", + "HepG2 375.8099205493927\n", + "K562 456.08897161483765\n", + "456.0923991203308\n", + "462.8749210834503\n" + ] + } + ], + "source": [ + "stride = 6400\n", + "itime = time.time()\n", + "celltype_mdta = []\n", + "celltype_labels = []\n", + "\n", + "for ct in _all_celltypes:\n", + " ct_labels_bw_path = _data_dir + \"labels/MAX/MAX_{}.bigwig\".format(ct)\n", + " df = mdf_posamb[mdf_posamb['celltype'] == ct]\n", + " df['window_start'] = stride*(df['start'] // stride)\n", + " uniq_windows = np.unique([\"{}:{}\".format(x[0], x[1]) for x in zip(df['chr'], df['window_start'])])\n", + " df_construction = []\n", + " mdta_labels = []\n", + " \n", + " bw = pyBigWig.open(ct_labels_bw_path)\n", + " num_reps = 0\n", + " for u in uniq_windows:\n", + " u_chr = u.split(':')[0]\n", + " u_start = int(u.split(':')[1])\n", + " u_end = u_start + stride\n", + " x = np.nan_to_num(bw.values(u_chr, u_start, u_end, numpy=True))\n", + " df_construction.append((u_chr, u_start, u_end))\n", + " mdta_labels.append(x[np.arange(0, len(x), 50)])\n", + " num_reps = num_reps + 1\n", + " celltype_mdta_df = pd.DataFrame(df_construction, columns=['chr', 'start', 'stop'])\n", + " celltype_mdta_df.insert(len(celltype_mdta_df.columns), 'celltype', ct)\n", + " celltype_mdta.append(celltype_mdta_df)\n", + " celltype_labels.append(np.stack(mdta_labels))\n", + " print(ct, time.time() - itime)\n", + " bw.close()\n", + " # break\n", + "print(time.time() - itime)\n", + "# _metadata_df\n", + "\n", + "pd.concat(celltype_mdta).to_csv(\n", + " _data_dir + 'labels/MAX/metadata_df.bed', \n", + " sep='\\t', header=False, index=False\n", + ")\n", + "np.save(_data_dir + 'labels/MAX/metadata_y.npy', np.vstack(celltype_labels))\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
chrstartstopcelltype
0chr10100025600100032000A549
1chr10100032000100038400A549
2chr10100064000100070400A549
3chr10100076800100083200A549
4chr10100083200100089600A549
...............
523753chrX9969920099705600K562
523754chrX99776009984000K562
523755chrX9990400099910400K562
523756chrX9992320099929600K562
523757chrX99993600100000000K562
\n", + "

523758 rows × 4 columns

\n", + "
" + ], + "text/plain": [ + " chr start stop celltype\n", + "0 chr10 100025600 100032000 A549\n", + "1 chr10 100032000 100038400 A549\n", + "2 chr10 100064000 100070400 A549\n", + "3 chr10 100076800 100083200 A549\n", + "4 chr10 100083200 100089600 A549\n", + "... ... ... ... ...\n", + "523753 chrX 99699200 99705600 K562\n", + "523754 chrX 9977600 9984000 K562\n", + "523755 chrX 99904000 99910400 K562\n", + "523756 chrX 99923200 99929600 K562\n", + "523757 chrX 99993600 100000000 K562\n", + "\n", + "[523758 rows x 4 columns]" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.read_csv(\n", + " _data_dir + 'labels/MAX/metadata_df.bed', sep='\\t', header=None, \n", + " index_col=None, names=['chr', 'start', 'stop', 'celltype']\n", + ")\n", + "# np.load(_data_dir + 'labels/MAX/metadata_y.npy')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/dataset_preprocessing/encode-tfbs/write_label_bigwig.py b/dataset_preprocessing/encode-tfbs/write_label_bigwig.py new file mode 100644 index 00000000..8dcf4f3e --- /dev/null +++ b/dataset_preprocessing/encode-tfbs/write_label_bigwig.py @@ -0,0 +1,93 @@ +import argparse, time +import numpy as np +import pyBigWig + +# Human hg19 chromosome names/lengths +chrom_sizes = {'chr1': 249250621, 'chr10': 135534747, 'chr11': 135006516, 'chr12': 133851895, 'chr13': 115169878, 'chr14': 107349540, 'chr15': 102531392, 'chr16': 90354753, 'chr17': 81195210, 'chr18': 78077248, 'chr19': 59128983, 'chr2': 243199373, 'chr20': 63025520, 'chr21': 48129895, 'chr22': 51304566, 'chr3': 198022430, 'chr4': 191154276, 'chr5': 180915260, 'chr6': 171115067, 'chr7': 159138663, 'chr8': 146364022, 'chr9': 141213431, 'chrX': 155270560} + +celltypes = ['A549', 'GM12878', 'H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] + +chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] + +def write_label_bigwig( + metadata_df, output_dir='codalab_archive' +): + dnases = {} + for ctype in celltypes: + itime = time.time() + bw = pyBigWig.open("{}/DNASE.{}.fc.signal.bigwig".format(input_dir, ctype)) + chromsizes = bw.chroms() + dn_dict = {} + for chrom in chromsizes: #chr_IDs: + x = bw.values(chrom, 0, chromsizes[chrom], numpy=True) + # half-precision makes things significantly smaller (less time to load) + dn_dict[chrom] = np.nan_to_num(x).astype(np.float16) + print("{}, {}. Time: {}".format(ctype, chrom, time.time() - itime)) + dnases[ctype] = dn_dict + + for ctype in dnases: + itime = time.time() + dn_dict = dnases[ctype] + + # Save as npz archive + np.savez_compressed('{}/{}_dnase'.format(output_dir, ctype), **dn_dict) + print("Saving npz archive for celltype {}. Time: {}".format(ctype, time.time() - itime)) + + +if __name__ == '__main__': + itime = time.time() + _data_dir = '../../examples/data/encode-tfbs_v1.0/' + _transcription_factor = 'MAX' + _train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] + _val_chroms = ['chr2', 'chr9', 'chr11'] + _test_chroms = ['chr1', 'chr8', 'chr21'] + _all_chroms = _train_chroms + _val_chroms + _test_chroms + _train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] + _val_celltype = ['A549'] + _test_celltype = ['GM12878'] + _all_celltypes = _train_celltypes + _val_celltype + _test_celltype + + # Read in metadata dataframe from training+validation data + train_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.train.labels.tsv.gz'.format(_transcription_factor)), sep='\t') + val_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.val.labels.tsv.gz'.format(_transcription_factor)), sep='\t') + training_df = train_regions_labeled# [np.isin(train_regions_labeled['chr'], _train_chroms)] + val_df = val_regions_labeled# [np.isin(val_regions_labeled['chr'], _test_chroms)] + all_df = pd.concat([training_df, val_df]) + + print(time.time() - itime) + + # Get the y values, and remove labels by default. + pd_list = [] + for ct in _all_celltypes: + tc_chr = all_df[['chr', 'start', 'stop', ct]] + tc_chr.columns = ['chr', 'start', 'stop', 'y'] + tc_chr = tc_chr[tc_chr['y'] != 'U'] + tc_chr['y'] = tc_chr['y'].replace({'U': 0, 'B': 1, 'A': 0.5}).values + tc_chr.insert(len(tc_chr.columns), 'celltype', ct) + pd_list.append(tc_chr) + print(ct, time.time() - itime) + _metadata_df = pd.concat(pd_list) + + print(time.time() - itime) + _unsorted_dir = _data_dir + 'labels/MAX/MAX_posamb.bed' + _sorted_dir = _unsorted_dir.replace('MAX_posamb', 'MAX_posamb.sorted') + _metadata_df.to_csv( + _unsorted_dir, sep='\t', header=False, index=False + ) + print(time.time() - itime) + + os.system('sort -k1,1 -k2,2n {} > {}'.format(_unsorted_dir, _sorted_dir)) + + mdf_posamb = pd.read_csv( + _sorted_dir, + sep='\t', header=None, index_col=None, names=['chr', 'start', 'stop', 'y', 'celltype'] + ) + chromsizes_list = [(k, v) for k, v in chrom_sizes.items()] + for ct in _all_celltypes: + ct_labels_bw_path = _data_dir + "labels/MAX/MAX_{}.bigwig".format(ct) + df = mdf_posamb[mdf_posamb['celltype'] == ct] + bw = pyBigWig.open(ct_labels_bw_path, "w") + bw.addHeader(chromsizes_list) + bw.addEntries(list(df['chr']), list(df['start']), ends=list(df['start']+50), values=list(df['y'])) + print(ct, time.time() - itime) + bw.close() diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 144954b6..7731dbd0 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -112,7 +112,7 @@ 'train_transform': None, 'eval_transform': None, 'loss_function': 'cross_entropy', - 'groupby_fields': ['celltype', 'y'], + 'groupby_fields': ['celltype'], 'val_metric': 'acc_avg', 'val_metric_decreasing': False, 'optimizer': 'Adam', diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb index 06440dc6..071a68d7 100644 --- a/examples/sbox_run_expt.ipynb +++ b/examples/sbox_run_expt.ipynb @@ -33,43 +33,61 @@ { "cell_type": "code", "execution_count": 1, - "metadata": {}, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "163 µs ± 343 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n" + "ename": "NameError", + "evalue": "name 'bw' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# import pyBigWig\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;31m# %timeit bw = pyBigWig.open(\"/users/abalsubr/wilds/examples/data/encode-tfbs_v1.0/DNASE.K562.fc.signal.bigwig\")\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mget_ipython\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_line_magic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'timeit'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"bw.values('chr1', 10000, 22800, numpy=True)\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mrun_line_magic\u001b[0;34m(self, magic_name, line, _stack_depth)\u001b[0m\n\u001b[1;32m 2334\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'local_ns'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_local_scope\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstack_depth\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2335\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuiltin_trap\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2336\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2337\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2338\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mtimeit\u001b[0;34m(self, line, cell, local_ns)\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/IPython/core/magic.py\u001b[0m in \u001b[0;36m\u001b[0;34m(f, *a, **k)\u001b[0m\n\u001b[1;32m 185\u001b[0m \u001b[0;31m# but it's overkill for just that one bit of state.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 186\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mmagic_deco\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 187\u001b[0;31m \u001b[0mcall\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 188\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcallable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/IPython/core/magics/execution.py\u001b[0m in \u001b[0;36mtimeit\u001b[0;34m(self, line, cell, local_ns)\u001b[0m\n\u001b[1;32m 1167\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mindex\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1168\u001b[0m \u001b[0mnumber\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m10\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0mindex\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1169\u001b[0;31m \u001b[0mtime_number\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtimer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtimeit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnumber\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1170\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtime_number\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0;36m0.2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1171\u001b[0m \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/IPython/core/magics/execution.py\u001b[0m in \u001b[0;36mtimeit\u001b[0;34m(self, number)\u001b[0m\n\u001b[1;32m 167\u001b[0m \u001b[0mgc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdisable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 168\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 169\u001b[0;31m \u001b[0mtiming\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minner\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mit\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtimer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 170\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 171\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mgcold\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36minner\u001b[0;34m(_it, _timer)\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name 'bw' is not defined" ] } ], "source": [ - "import pyBigWig\n", - "# %timeit bw = pyBigWig.open(\"/users/abalsubr/wilds/examples/data/encode-tfbs_v1.0/DNASE.K562.fc.signal.bigwig\")" + "# import pyBigWig\n", + "# %timeit bw = pyBigWig.open(\"/users/abalsubr/wilds/examples/data/encode-tfbs_v1.0/DNASE.K562.fc.signal.bigwig\")\n", + "%timeit bw.values('chr1', 10000, 22800, numpy=True)" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "WARNING:root:The OGB package is out of date. Your version is 1.2.4, while the latest version is 1.2.5.\n" + "WARNING:root:The OGB package is out of date. Your version is 1.2.4, while the latest version is 1.2.6.\n" ] } ], "source": [ - "import os, csv\n", + "import os, csv, sys\n", + "os.environ['CUDA_VISIBLE_DEVICES'] = '4'\n", + "\n", "import time\n", "import argparse\n", "import numpy as np, pandas as pd\n", "import torch\n", "import torch.nn as nn\n", "import torchvision\n", - "import sys\n", + "import pyBigWig\n", "from collections import defaultdict\n", "\n", "from wilds.common.data_loaders import get_train_loader, get_eval_loader\n", @@ -85,16 +103,16 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" + "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" ] }, - "execution_count": 3, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -189,7 +207,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -207,7 +225,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -236,7 +254,7 @@ "Resize scale: None\n", "Max token length: None\n", "Loss function: cross_entropy\n", - "Groupby fields: ['celltype', 'y']\n", + "Groupby fields: ['celltype']\n", "Group dro step size: None\n", "Coral penalty weight: None\n", "Irm lambda: None\n", @@ -270,41 +288,19 @@ "Progress bar: False\n", "Resume: False\n", "\n", - "chr2 3.6633927822113037\n", - "chr3 6.7115819454193115\n", - "chr4 9.648637771606445\n", - "chr5 12.439441919326782\n", - "chr6 15.091757774353027\n", - "chr7 17.542895555496216\n", - "chr9 19.707583904266357\n", - "chr10 21.79905652999878\n", - "chr11 23.86957049369812\n", - "chr12 25.918642044067383\n", - "chr13 27.675577402114868\n", - "chr14 29.3148353099823\n", - "chr15 30.881144046783447\n", - "chr16 32.271193504333496\n", - "chr17 33.51785063743591\n", - "chr18 34.72123050689697\n", - "chr19 35.627156257629395\n", - "chr20 36.59872794151306\n", - "chr22 37.37847852706909\n", - "chrX 39.77280807495117\n", - "chr1 43.60475468635559\n", - "chr8 45.86070203781128\n", - "chr21 46.59553360939026\n" - ] - }, - { - "ename": "NameError", - "evalue": "name '_all_celltypes' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;31m# Data\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m full_dataset = supported.datasets[config.dataset](\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0mroot_dir\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mroot_dir\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mdownload\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdownload\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/wilds/datasets/encodetfbs_dataset.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, root_dir, download, split_scheme)\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0;31m# Get the y values, and remove ambiguous labels by default.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 102\u001b[0m \u001b[0mpd_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 103\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mct\u001b[0m \u001b[0;32min\u001b[0m \u001b[0m_all_celltypes\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 104\u001b[0m \u001b[0mtc_chr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mall_df\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'start'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'stop'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mct\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[0mtc_chr\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcolumns\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'start'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'stop'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'y'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mNameError\u001b[0m: name '_all_celltypes' is not defined" + "chr3 5.088634967803955\n", + "chr4 9.974164009094238\n", + "chr5 15.149149894714355\n", + "chr6 19.728455066680908\n", + "chr7 23.769655466079712\n", + "chr10 29.31521511077881\n", + "chr12 32.78225326538086\n", + "chr13 35.67028570175171\n", + "chr14 46.721638441085815\n", + "chr15 92.16564106941223\n", + "chr16 96.26218318939209\n", + "chr17 114.85105729103088\n", + "chr18 116.09504199028015\n" ] } ], @@ -354,20 +350,13 @@ }, { "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" + "execution_count": 5, + "metadata": { + "jupyter": { + "source_hidden": true } - ], + }, + "outputs": [], "source": [ "import copy\n", "full_dataset_camelyon17 = copy.deepcopy(full_dataset)\n", @@ -376,56 +365,58 @@ "# print(config_camelyon.train_transform, config_encode.train_transform)\n" ] }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'full_dataset' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mfull_dataset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mNameError\u001b[0m: name 'full_dataset' is not defined" - ] - } - ], - "source": [ - "full_dataset" - ] - }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## 2) Initialize dataset object" + "## 2) Initialize dataset object (trial version)" ] }, { "cell_type": "code", "execution_count": 8, - "metadata": {}, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true, + "source_hidden": true + } + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "chr2 3.90254282951355\n", - "chr9 6.149690628051758\n", - "chr11 8.327073097229004\n", - "chr1 12.291624546051025\n", - "chr8 14.624409675598145\n", - "chr21 15.413429021835327\n", - "H1-hESC 15.415196895599365\n", - "HCT116 15.415941953659058\n", - "HeLa-S3 15.416455030441284\n", - "HepG2 15.417592763900757\n", - "K562 15.418397426605225\n", - "A549 15.41891360282898\n", - "GM12878 15.419732332229614\n" + "chr3 3.0872416496276855\n", + "chr4 6.014077425003052\n", + "chr5 8.789116859436035\n", + "chr6 11.409600496292114\n", + "chr7 13.844907283782959\n", + "chr10 15.919893264770508\n", + "chr12 17.969276189804077\n", + "chr13 19.71941637992859\n", + "chr14 21.34366464614868\n", + "chr15 22.900768995285034\n", + "chr16 24.27766728401184\n", + "chr17 25.519333600997925\n", + "chr18 26.714667797088623\n", + "chr19 27.614336490631104\n", + "chr20 28.57899522781372\n", + "chr22 29.353068113327026\n", + "chrX 31.731130599975586\n", + "chr2 35.449124813079834\n", + "chr9 37.5920934677124\n", + "chr11 39.65406608581543\n", + "chr1 43.44736051559448\n", + "chr8 45.68234419822693\n", + "chr21 46.41120982170105\n", + "H1-hESC 46.41424226760864\n", + "HCT116 46.41492676734924\n", + "HeLa-S3 46.41563010215759\n", + "HepG2 46.41687893867493\n", + "K562 46.41777992248535\n", + "A549 46.41860294342041\n", + "GM12878 46.41955780982971\n" ] } ], @@ -446,18 +437,18 @@ "_dataset_name = 'encode-tfbs'\n", "_version = '1.0'\n", "_download_url = 'https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/'\n", - "_data_dir = 'data/encode-tfbs_v1.0'\n", + "_data_dir = 'data/encode-tfbs_v1.0/'\n", "_y_size = 1\n", "_n_classes = 2\n", "\n", - "# _train_chroms = ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX']\n", - "_train_chroms = ['chr2', 'chr9', 'chr11']\n", + "_train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX']\n", + "_val_chroms = ['chr2', 'chr9', 'chr11']\n", "_test_chroms = ['chr1', 'chr8', 'chr21']\n", "_transcription_factor = 'MAX'\n", "_train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", "_val_celltype = ['A549']\n", "_test_celltype = ['GM12878']\n", - "_all_chroms = _train_chroms + _test_chroms\n", + "_all_chroms = _train_chroms + _val_chroms + _test_chroms\n", "_all_celltypes = _train_celltypes + _val_celltype + _test_celltype\n", "\n", "_metadata_map = {}\n", @@ -486,7 +477,7 @@ "sequence_filename = os.path.join(_data_dir, 'sequence.npz')\n", "seq_arr = np.load(sequence_filename)\n", "_seq_bp = {}\n", - "for chrom in _all_chroms: #seq_arr:\n", + "for chrom in _all_chroms:\n", " _seq_bp[chrom] = seq_arr[chrom]\n", " print(chrom, time.time() - itime)\n", "\n", @@ -504,533 +495,75 @@ " \"\"\"\n", " dnase_bw_path = os.path.join(_data_dir, 'Leopard_dnase/{}.bigwig'.format(ct))\n", " _dnase_allcelltypes[ct] = pyBigWig.open(dnase_bw_path)\n", - " print(ct, time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "74.06488299369812\n" - ] - } - ], - "source": [ - "itime = time.time()\n", - "\n", - "# Read in metadata dataframe from training+validation data\n", - "train_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.train.labels.tsv.gz'.format(_transcription_factor)), sep='\\t')\n", - "val_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.val.labels.tsv.gz'.format(_transcription_factor)), sep='\\t')\n", - "training_df = train_regions_labeled[np.isin(train_regions_labeled['chr'], _train_chroms)]\n", - "val_df = val_regions_labeled[np.isin(val_regions_labeled['chr'], _test_chroms)]\n", - "all_df = pd.concat([training_df, val_df])\n", - "\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - ":8: SettingWithCopyWarning: \n", - "A value is trying to be set on a copy of a slice from a DataFrame.\n", - "Try using .loc[row_indexer,col_indexer] = value instead\n", - "\n", - "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", - " tc_chr['y'] = tc_chr['y'].replace({'U': 0, 'B': 1, 'A': 0.5}).values\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "13.203907012939453\n", - "22.746795415878296\n", - "32.076903104782104\n", - "41.43525505065918\n", - "51.37267017364502\n", - "61.07364773750305\n", - "71.10506510734558\n", - "100.72125267982483\n" - ] - } - ], - "source": [ - "itime = time.time()\n", - "\n", - "# Get the y values, and remove ambiguous labels by default.\n", - "pd_list = []\n", - "for ct in _all_celltypes:\n", - " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", - " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", - " tc_chr['y'] = tc_chr['y'].replace({'U': 0, 'B': 1, 'A': 0.5}).values\n", - " \n", - " # # Now filter out ambiguous labels\n", - " # non_ambig_mask = (y_array != -1)\n", - " # tc_chr['y'] = y_array\n", - " # tc_chr = tc_chr[non_ambig_mask]\n", - " \n", - " tc_chr.insert(len(tc_chr.columns), 'celltype', ct)\n", - " pd_list.append(tc_chr)\n", - " print(time.time() - itime)\n", - "metadata_df = pd.concat(pd_list)\n", - "\n", - "print(time.time() - itime)\n", + " print(ct, time.time() - itime)\n", "\n", - "_metadata_df = metadata_df\n" + "_metadata_df = pd.read_csv(\n", + " _data_dir + 'labels/MAX/metadata_df.bed', sep='\\t', header=None, \n", + " index_col=None, names=['chr', 'start', 'stop', 'celltype']\n", + ")" ] }, { "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "H1-hESC 8.80908489227295\n", - "H1-hESC 12.474784135818481\n", - "H1-hESC 15.258162498474121\n", - "HCT116 23.023517370224\n", - "HCT116 25.439095735549927\n", - "HCT116 27.21790099143982\n", - "HeLa-S3 34.81065845489502\n", - "HeLa-S3 35.2776780128479\n", - "HeLa-S3 36.60090255737305\n", - "HepG2 44.36072087287903\n", - "HepG2 44.74501991271973\n", - "HepG2 46.02569603919983\n", - "K562 53.825233697891235\n", - "K562 54.182188749313354\n", - "K562 55.44522547721863\n", - "A549 62.980581283569336\n", - "A549 63.34522008895874\n", - "A549 64.59721446037292\n", - "GM12878 72.41460752487183\n", - "GM12878 72.7955391407013\n", - "GM12878 74.05369997024536\n" - ] + "execution_count": 325, + "metadata": { + "jupyter": { + "source_hidden": true } - ], - "source": [ - "# np.unique(_metadata_df['y'])\n", - "\n", - "# Downsample negatives to balance each celltype\n", - "samp_ndces = []\n", - "itime = time.time()\n", - "neg_msk = (_metadata_df['y'] == 0)\n", - "pos_msk = (_metadata_df['y'] != 0)\n", - "for ct in _all_celltypes:\n", - " celltype_msk = (_metadata_df['celltype'] == ct)\n", - " print(ct, time.time() - itime)\n", - " neg_ct_msk = np.logical_and(celltype_msk, neg_msk)\n", - " pos_ct_msk = np.logical_and(celltype_msk, pos_msk)\n", - " print(ct, time.time() - itime)\n", - " neg_ndces = np.where(neg_ct_msk)[0]\n", - " pos_ndces = np.where(pos_ct_msk)[0]\n", - " np.random.seed(42)\n", - " samp_neg_ndces = np.random.choice(neg_ndces, size=len(pos_ndces), replace=False)\n", - " samp_ndces.extend(samp_neg_ndces)\n", - " samp_ndces.extend(pos_ndces)\n", - " print(ct, time.time() - itime)\n", - "_metadata_df = _metadata_df.iloc[samp_ndces, :]" - ] - }, - { - "cell_type": "code", - "execution_count": 145, - "metadata": {}, + }, "outputs": [], "source": [ - "def get_random_label_vec(metadata_df, output_size=128):\n", - " # Sample a positively labeled region at random\n", - " pos_mdf = metadata_df[metadata_df['y'] == 1] #.iloc[ metadata_df['chr'] == s['chr'], : ]\n", - " pos_seed_region = pos_mdf.iloc[np.random.randint(pos_mdf.shape[0])]\n", - "\n", + "def get_random_label_vec(\n", + " metadata_df, seed_chr, seed_celltype, seed_start, output_size=128\n", + "):\n", + " \"\"\"\n", + " Given a coordinate in a celltype, gets the labels of \n", + " the `output_size` 200bp bins from that coordinate onward. \n", + " \"\"\"\n", + " itime = time.time()\n", + " \n", " # Extract regions from this chromosome in this celltype, to get a window of labels from\n", - " chr_msk = np.array(metadata_df['chr']) == pos_seed_region['chr']\n", - " ct_msk = np.array(metadata_df['celltype']) == pos_seed_region['celltype']\n", - " mdf = metadata_df[chr_msk & ct_msk]\n", + " # print(time.time() - itime)\n", + " # chr_msk = np.array(metadata_df['chr']) == seed_region['chr']\n", + " # print(time.time() - itime)\n", + " # ct_msk = np.array(metadata_df['celltype']) == seed_region['celltype']\n", + " # mdf = metadata_df[chr_msk & ct_msk]\n", + " seq_size = output_size*50\n", + " mdf = metadata_df.loc[\n", + " (metadata_df['chr'] == seed_chr) & \n", + " (metadata_df['celltype'] == seed_celltype) & \n", + " (metadata_df['start'] >= seed_start) & \n", + " (metadata_df['stop'] < seed_start+seq_size)\n", + " ]\n", + " print(time.time() - itime)\n", "\n", " # Get labels\n", - " start_ndx = np.where(mdf['start'] == pos_seed_region['start'])[0][0]\n", - " y_label_vec = mdf.iloc[start_ndx:start_ndx+output_size, :]['y']" + " y_label_vec = np.zeros(output_size)\n", + " y_label_vec[(mdf['start'] - seed_start) // 50] = mdf['y']\n", + " return mdf, y_label_vec" ] }, { "cell_type": "code", - "execution_count": 146, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
chrstartstopycelltype
5937924chr86008000.0HepG2
5937925chr86508500.0HepG2
5937926chr87009000.0HepG2
5937927chr87509500.0HepG2
5937928chr880010000.0HepG2
..................
8843006chr81463632001463634000.0HepG2
8843007chr81463632501463634500.0HepG2
8843008chr81463633001463635000.0HepG2
8843009chr81463633501463635500.0HepG2
8843010chr81463634001463636000.0HepG2
\n", - "

2905087 rows × 5 columns

\n", - "
" - ], - "text/plain": [ - " chr start stop y celltype\n", - "5937924 chr8 600 800 0.0 HepG2\n", - "5937925 chr8 650 850 0.0 HepG2\n", - "5937926 chr8 700 900 0.0 HepG2\n", - "5937927 chr8 750 950 0.0 HepG2\n", - "5937928 chr8 800 1000 0.0 HepG2\n", - "... ... ... ... ... ...\n", - "8843006 chr8 146363200 146363400 0.0 HepG2\n", - "8843007 chr8 146363250 146363450 0.0 HepG2\n", - "8843008 chr8 146363300 146363500 0.0 HepG2\n", - "8843009 chr8 146363350 146363550 0.0 HepG2\n", - "8843010 chr8 146363400 146363600 0.0 HepG2\n", - "\n", - "[2905087 rows x 5 columns]" - ] - }, - "execution_count": 146, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 154, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "8809571 1.0\n", - "8809572 1.0\n", - "8809573 1.0\n", - "8809574 1.0\n", - "8809575 1.0\n", - "8809576 0.5\n", - "8809577 0.5\n", - "8809578 0.5\n", - "8809579 1.0\n", - "8809580 1.0\n", - "8809581 1.0\n", - "8809582 1.0\n", - "8809583 1.0\n", - "8809584 1.0\n", - "8809585 0.5\n", - "8809586 0.5\n", - "8809587 0.0\n", - "8809588 0.0\n", - "8809589 0.0\n", - "8809590 0.0\n", - "8809591 0.0\n", - "8809592 0.0\n", - "8809593 0.0\n", - "8809594 0.0\n", - "8809595 0.0\n", - "8809596 0.0\n", - "8809597 0.0\n", - "8809598 0.0\n", - "8809599 0.0\n", - "8809600 0.0\n", - "8809601 0.0\n", - "8809602 0.0\n", - "Name: y, dtype: float64" - ] - }, - "execution_count": 154, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 107, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([0.])" - ] - }, - "execution_count": 107, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.unique(mdf[:256+start_bin]['y'])" - ] - }, - { - "cell_type": "code", - "execution_count": 150, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "chr chr8\n", - "start 144691450\n", - "stop 144691650\n", - "y 1.0\n", - "celltype HepG2\n", - "Name: 8809571, dtype: object" - ] - }, - "execution_count": 150, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "pos_seed_region" - ] - }, - { - "cell_type": "code", - "execution_count": 98, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([ 600, 600, 600, ..., 135005900, 135005900,\n", - " 135005900])" - ] - }, - "execution_count": 98, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# arr = metadata_df[mdf_msk]['start']\n", - "#arr == \n", - "np.sort(arr)" - ] - }, - { - "cell_type": "code", - "execution_count": 69, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "116.39193439483643\n" - ] - } - ], - "source": [ - "itime = time.time()\n", - "lts = ['{}:{}-{}'.format(x[0], x[1], x[2]) for x in zip(metadata_df['chr'], metadata_df['start'], metadata_df['stop'])]\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 55, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "202800" - ] - }, - "execution_count": 55, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#_metadata_df['y']\n", - "# s = metadata_df.iloc[np.array(pos_msk), :]\n", - "ntry = s.iloc[5]\n", - "ntry['start'] + 12800\n", - "# s['chr'], s['start'], s['stop'] # np.unique(s['chr'], return_counts=True)\n", - "# all_df\n", - "# metadata_df" - ] - }, - { - "cell_type": "code", - "execution_count": 16, + "execution_count": 13, "metadata": { - "collapsed": true, "jupyter": { - "outputs_hidden": true + "source_hidden": true } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "H1-hESC 7.871182918548584\n", - "H1-hESC 8.298529148101807\n", - "H1-hESC 9.57175898551941\n", - "HCT116 17.01794719696045\n", - "HCT116 17.36267113685608\n", - "HCT116 18.669682025909424\n", - "HeLa-S3 26.405478954315186\n", - "HeLa-S3 26.759119272232056\n", - "HeLa-S3 28.043395042419434\n", - "HepG2 35.623862981796265\n", - "HepG2 35.98245143890381\n", - "HepG2 37.29869079589844\n", - "K562 44.92080807685852\n", - "K562 45.256179332733154\n", - "K562 46.7364935874939\n", - "A549 54.39264512062073\n", - "A549 54.74424934387207\n", - "A549 56.03351712226868\n", - "GM12878 63.745240211486816\n", - "GM12878 64.1029920578003\n", - "GM12878 65.43286633491516\n" - ] - } - ], + "outputs": [], "source": [ "train_regions_mask = np.isin(_metadata_df['chr'], _train_chroms)\n", - "val_regions_mask = np.isin(_metadata_df['chr'], _test_chroms)\n", + "val_regions_mask = np.isin(_metadata_df['chr'], _val_chroms)\n", + "test_regions_mask = np.isin(_metadata_df['chr'], _test_chroms)\n", "train_celltype_mask = np.isin(_metadata_df['celltype'], _train_celltypes)\n", "val_celltype_mask = np.isin(_metadata_df['celltype'], _val_celltype)\n", "test_celltype_mask = np.isin(_metadata_df['celltype'], _test_celltype)\n", "\n", "split_array = -1*np.ones(_metadata_df.shape[0]).astype(int)\n", "split_array[np.logical_and(train_regions_mask, train_celltype_mask)] = _split_dict['train']\n", - "split_array[np.logical_and(val_regions_mask, test_celltype_mask)] = _split_dict['test']\n", - "# Validate using test chr, either using a designated validation cell line ('val') or a training cell line ('id_val')\n", + "split_array[np.logical_and(test_regions_mask, test_celltype_mask)] = _split_dict['test']\n", + "# Validate using validation chr, either using a designated validation cell line ('val') or a training cell line ('id_val')\n", "split_array[np.logical_and(val_regions_mask, val_celltype_mask)] = _split_dict['val']\n", "split_array[np.logical_and(val_regions_mask, train_celltype_mask)] = _split_dict['id_val']\n", "\n", @@ -1039,19 +572,22 @@ "else:\n", " raise ValueError(f'Split scheme {_split_scheme} not recognized')\n", "\n", + "metadata_mask = (_metadata_df['split'] != -1)\n", "_metadata_df = _metadata_df[_metadata_df['split'] != -1]\n", - "_split_array = _metadata_df['split'].values\n", "\n", "chr_ints = _metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(_metadata_map['chr'])] )).values\n", "celltype_ints = _metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(_metadata_map['celltype'])] )).values\n", - "_y_array = torch.LongTensor(np.array(_metadata_df['y']))\n", + "_split_array = _metadata_df['split'].values\n", + "\n", + "_y_array = torch.Tensor(np.load(_data_dir + 'labels/MAX/metadata_y.npy'))\n", + "_y_array = _y_array[metadata_mask]\n", "\n", "_metadata_array = torch.stack(\n", " (torch.LongTensor(chr_ints), \n", - " torch.LongTensor(celltype_ints), \n", - " _y_array),\n", + " torch.LongTensor(celltype_ints)\n", + " ),\n", " dim=1)\n", - "_metadata_fields = ['chr', 'celltype', 'y']" + "_metadata_fields = ['chr', 'celltype']" ] }, { @@ -1063,8 +599,12 @@ }, { "cell_type": "code", - "execution_count": 3, - "metadata": {}, + "execution_count": 24, + "metadata": { + "jupyter": { + "source_hidden": true + } + }, "outputs": [], "source": [ "import os, time\n", @@ -1099,17 +639,17 @@ " self._version = '1.0'\n", " self._download_url = 'https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/'\n", " self._data_dir = self.initialize_data_dir(root_dir, download)\n", - " self._y_size = 1\n", - " self._n_classes = 2\n", + " self._y_size = 128\n", + " # self._n_classes = 2\n", " \n", - " # self._train_chroms = ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX']\n", - " self._train_chroms = ['chr2', 'chr9', 'chr11']\n", + " self._train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX']\n", + " self._val_chroms = ['chr2', 'chr9', 'chr11']\n", " self._test_chroms = ['chr1', 'chr8', 'chr21']\n", " self._transcription_factor = 'MAX'\n", " self._train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", " self._val_celltype = ['A549']\n", " self._test_celltype = ['GM12878']\n", - " self._all_chroms = self._train_chroms + self._test_chroms\n", + " self._all_chroms = self._train_chroms + self._val_chroms + self._test_chroms\n", " self._all_celltypes = self._train_celltypes + self._val_celltype + self._test_celltype\n", " \n", " self._metadata_map = {}\n", @@ -1143,6 +683,9 @@ " print(chrom, time.time() - itime)\n", " \n", " self._dnase_allcelltypes = {}\n", + " ct = 'avg'\n", + " dnase_avg_bw_path = os.path.join(self._data_dir, 'Leopard_dnase/{}.bigwig'.format(ct))\n", + " self._dnase_allcelltypes[ct] = pyBigWig.open(dnase_avg_bw_path)\n", " for ct in self._all_celltypes:\n", " \"\"\"\n", " dnase_filename = os.path.join(self._data_dir, '{}_dnase.npz'.format(ct))\n", @@ -1151,63 +694,25 @@ " for chrom in self._all_chroms: #self._seq_bp:\n", " self._dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom]\n", " \"\"\"\n", - " self._dnase_allcelltypes[ct] = os.path.join(self._data_dir, 'DNASE.{}.fc.signal.bigwig'.format(ct))\n", - " print(ct, time.time() - itime)\n", - " \n", - " # Read in metadata dataframe from training+validation data\n", - " train_regions_labeled = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.train.labels.tsv.gz'.format(self._transcription_factor)), sep='\\t')\n", - " val_regions_labeled = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.val.labels.tsv.gz'.format(self._transcription_factor)), sep='\\t')\n", - " training_df = train_regions_labeled[np.isin(train_regions_labeled['chr'], self._train_chroms)]\n", - " val_df = val_regions_labeled[np.isin(val_regions_labeled['chr'], self._test_chroms)]\n", - " all_df = pd.concat([training_df, val_df])\n", + " dnase_bw_path = os.path.join(self._data_dir, 'Leopard_dnase/{}.bigwig'.format(ct))\n", + " self._dnase_allcelltypes[ct] = pyBigWig.open(dnase_bw_path)\n", " \n", - " # Get the y values, and remove ambiguous labels by default.\n", - " pd_list = []\n", - " for ct in _all_celltypes:\n", - " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", - " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", - " y_array = tc_chr['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", - "\n", - " # Now filter out ambiguous labels\n", - " non_ambig_mask = (y_array != -1)\n", - " tc_chr['y'] = y_array\n", - " tc_chr = tc_chr[non_ambig_mask]\n", - "\n", - " tc_chr.insert(len(tc_chr.columns), 'celltype', ct)\n", - " pd_list.append(tc_chr)\n", - " print(time.time() - itime)\n", - " self._metadata_df = pd.concat(pd_list)\n", - " \n", - " # Downsample negatives to balance each celltype\n", - " samp_ndces = []\n", - " itime = time.time()\n", - " neg_msk = (self._metadata_df['y'] == 0)\n", - " pos_msk = (self._metadata_df['y'] == 1)\n", - " for ct in _all_celltypes:\n", - " celltype_msk = (self._metadata_df['celltype'] == ct)\n", - " print(ct, time.time() - itime)\n", - " neg_ct_msk = np.logical_and(celltype_msk, neg_msk)\n", - " pos_ct_msk = np.logical_and(celltype_msk, pos_msk)\n", - " print(ct, time.time() - itime)\n", - " neg_ndces = np.where(neg_ct_msk)[0]\n", - " pos_ndces = np.where(pos_ct_msk)[0]\n", - " np.random.seed(42)\n", - " samp_neg_ndces = np.random.choice(neg_ndces, size=len(pos_ndces), replace=False)\n", - " samp_ndces.extend(samp_neg_ndces)\n", - " samp_ndces.extend(pos_ndces)\n", - " print(ct, time.time() - itime)\n", - " self._metadata_df = self._metadata_df.iloc[samp_ndces, :]\n", + " self._metadata_df = pd.read_csv(\n", + " self._data_dir + '/labels/MAX/metadata_df.bed', sep='\\t', header=None, \n", + " index_col=None, names=['chr', 'start', 'stop', 'celltype']\n", + " )\n", " \n", " train_regions_mask = np.isin(self._metadata_df['chr'], self._train_chroms)\n", - " val_regions_mask = np.isin(self._metadata_df['chr'], self._test_chroms)\n", + " val_regions_mask = np.isin(self._metadata_df['chr'], self._val_chroms)\n", + " test_regions_mask = np.isin(self._metadata_df['chr'], self._test_chroms)\n", " train_celltype_mask = np.isin(self._metadata_df['celltype'], self._train_celltypes)\n", " val_celltype_mask = np.isin(self._metadata_df['celltype'], self._val_celltype)\n", " test_celltype_mask = np.isin(self._metadata_df['celltype'], self._test_celltype)\n", " \n", " split_array = -1*np.ones(self._metadata_df.shape[0]).astype(int)\n", " split_array[np.logical_and(train_regions_mask, train_celltype_mask)] = self._split_dict['train']\n", - " split_array[np.logical_and(val_regions_mask, test_celltype_mask)] = self._split_dict['test']\n", - " # Validate using test chr, either using a designated validation cell line ('val') or a training cell line ('id_val')\n", + " split_array[np.logical_and(test_regions_mask, test_celltype_mask)] = self._split_dict['test']\n", + " # Validate using validation chr, either using a designated validation cell line ('val') or a training cell line ('id_val')\n", " split_array[np.logical_and(val_regions_mask, val_celltype_mask)] = self._split_dict['val']\n", " split_array[np.logical_and(val_regions_mask, train_celltype_mask)] = self._split_dict['id_val']\n", " \n", @@ -1216,19 +721,21 @@ " else:\n", " raise ValueError(f'Split scheme {self._split_scheme} not recognized')\n", " \n", + " metadata_mask = (self._metadata_df['split'] != -1)\n", " self._metadata_df = self._metadata_df[self._metadata_df['split'] != -1]\n", - " self._split_array = self._metadata_df['split'].values\n", " \n", " chr_ints = self._metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['chr'])] )).values\n", " celltype_ints = self._metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['celltype'])] )).values\n", - " self._y_array = torch.LongTensor(np.array(self._metadata_df['y']))\n", + " self._split_array = self._metadata_df['split'].values\n", + " self._y_array = torch.Tensor(np.load(self._data_dir + '/labels/MAX/metadata_y.npy'))\n", + " self._y_array = self._y_array[metadata_mask]\n", " \n", " self._metadata_array = torch.stack(\n", " (torch.LongTensor(chr_ints), \n", - " torch.LongTensor(celltype_ints), \n", - " self._y_array),\n", + " torch.LongTensor(celltype_ints)\n", + " ),\n", " dim=1)\n", - " self._metadata_fields = ['chr', 'celltype', 'y']\n", + " self._metadata_fields = ['chr', 'celltype']\n", " \n", " self._eval_grouper = CombinatorialGrouper(\n", " dataset=self,\n", @@ -1237,33 +744,43 @@ " self._metric = Accuracy()\n", " \n", " super().__init__(root_dir, download, split_scheme)\n", - "\n", - " def get_input(self, idx):\n", + " \n", + " \"\"\"\n", + " def get_random_label_vec(metadata_df, output_size=128):\n", + " # Sample a positively labeled region at random\n", + " pos_mdf = metadata_df[metadata_df['y'] == 1] #.iloc[ metadata_df['chr'] == s['chr'], : ]\n", + " pos_seed_region = pos_mdf.iloc[np.random.randint(pos_mdf.shape[0])]\n", + "\n", + " # Extract regions from this chromosome in this celltype, to get a window of labels from\n", + " chr_msk = np.array(metadata_df['chr']) == pos_seed_region['chr']\n", + " ct_msk = np.array(metadata_df['celltype']) == pos_seed_region['celltype']\n", + " mdf = metadata_df[chr_msk & ct_msk]\n", + "\n", + " # Get labels\n", + " start_ndx = np.where(mdf['start'] == pos_seed_region['start'])[0][0]\n", + " y_label_vec = mdf.iloc[start_ndx:start_ndx+output_size, :]['y']\n", + " \"\"\"\n", + " \n", + " def get_input(self, idx, window_size=12800):\n", " \"\"\"\n", " Returns x for a given idx in metadata_array, which has been filtered to only take windows with the desired stride.\n", " Computes this from: \n", " (1) sequence features in self._seq_bp\n", - " (2) DNase bigwig file paths in self._dnase_allcelltypes\n", - " (3) Metadata for the index (location along the genome with 200bp window width)\n", + " (2) DNase bigwig file handles in self._dnase_allcelltypes\n", + " (3) Metadata for the index (location along the genome with 6400bp window width)\n", + " (4) Window_size, the length of sequence returned (centered on the 6400bp region in (3))\n", " \"\"\"\n", - " \n", " this_metadata = self._metadata_df.iloc[idx, :]\n", - " \"\"\"\n", - " flank_size = 400\n", - " interval_start = this_metadata['start'] - flank_size\n", - " interval_end = this_metadata['stop'] + flank_size\n", - " dnase_this = self._dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end]\n", - " seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end]\n", - " return torch.tensor(np.column_stack([seq_this, dnase_this]))\n", - " \"\"\"\n", - " window_size = 12800\n", - " interval_start = this_metadata['start']\n", - " interval_end = this_metadata['stop'] + window_size\n", + " interval_start = this_metadata['start'] - int(window_size/4)\n", + " interval_end = interval_start + window_size #this_metadata['stop']\n", " seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end]\n", " dnase_bw = self._dnase_allcelltypes[this_metadata['celltype']]\n", " dnase_this = dnase_bw.values(chrom, interval_start, interval_end, numpy=True)\n", - " return torch.tensor(np.column_stack([seq_this, dnase_this]))\n", - " \n", + " dnase_avg = self._dnase_allcelltypes['avg'].values(chrom, interval_start, interval_end, numpy=True)\n", + " return torch.tensor(np.column_stack(\n", + " [np.nan_to_num(seq_this), np.nan_to_num(dnase_this), np.nan_to_num(dnase_avg)]\n", + " ))\n", + "\n", " def eval(self, y_pred, y_true, metadata):\n", " return self.standard_group_eval(\n", " self._metric,\n", @@ -1273,7 +790,7 @@ }, { "cell_type": "code", - "execution_count": 107, + "execution_count": 26, "metadata": { "collapsed": true, "jupyter": { @@ -1285,40 +802,29 @@ "name": "stdout", "output_type": "stream", "text": [ - "chr2 3.962329387664795\n", - "chr9 6.259538888931274\n", - "chr11 8.446826934814453\n", - "chr1 12.49940538406372\n", - "chr8 14.91869592666626\n", - "chr21 15.700694799423218\n", - "H1-hESC 23.95099449157715\n", - "HCT116 31.26502823829651\n", - "HeLa-S3 39.382277488708496\n", - "HepG2 47.24500226974487\n", - "K562 55.079211711883545\n", - "A549 62.405343532562256\n", - "GM12878 70.00356984138489\n", - "H1-hESC 8.160386562347412\n", - "H1-hESC 8.546203374862671\n", - "H1-hESC 9.868412971496582\n", - "HCT116 17.121587991714478\n", - "HCT116 17.524660110473633\n", - "HCT116 18.90956425666809\n", - "HeLa-S3 26.98938488960266\n", - "HeLa-S3 27.376858234405518\n", - "HeLa-S3 28.7989022731781\n", - "HepG2 36.29348182678223\n", - "HepG2 36.668752908706665\n", - "HepG2 38.151512145996094\n", - "K562 45.96789216995239\n", - "K562 46.33995985984802\n", - "K562 47.87280249595642\n", - "A549 55.380892276763916\n", - "A549 55.75924301147461\n", - "A549 57.22686314582825\n", - "GM12878 65.09361720085144\n", - "GM12878 65.50619888305664\n", - "GM12878 66.9196424484253\n" + "chr3 3.0425407886505127\n", + "chr4 5.967821359634399\n", + "chr5 8.747126340866089\n", + "chr6 11.370141744613647\n", + "chr7 13.802208423614502\n", + "chr10 15.875979900360107\n", + "chr12 17.929850339889526\n", + "chr13 19.67976665496826\n", + "chr14 21.306750059127808\n", + "chr15 22.866544723510742\n", + "chr16 24.241100788116455\n", + "chr17 25.480982303619385\n", + "chr18 26.677065134048462\n", + "chr19 27.579110622406006\n", + "chr20 28.545915603637695\n", + "chr22 29.323810577392578\n", + "chrX 31.698036670684814\n", + "chr2 35.40705943107605\n", + "chr9 37.5518524646759\n", + "chr11 39.61783218383789\n", + "chr1 43.411964893341064\n", + "chr8 45.64823389053345\n", + "chr21 46.377281188964844\n" ] } ], @@ -1331,91 +837,28 @@ ] }, { - "cell_type": "code", - "execution_count": 2, + "cell_type": "markdown", "metadata": {}, - "outputs": [ - { - "ename": "ModuleNotFoundError", - "evalue": "No module named 'pyBigWig'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# full_dataset_encode\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mpyBigWig\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'pyBigWig'" - ] - } - ], "source": [ - "# full_dataset_encode\n", - "import pyBigWig" + "# Initialize algorithm" ] }, { "cell_type": "code", - "execution_count": 39, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "a = np.random.choice(1210796, size=128)\n", - "seta = [full_dataset_encode.get_input(x) for x in a]\n", - "seta[0].shape\n", - "\n", - "# full_dataset = copy.deepcopy(full_dataset_encode)\n", - "# full_dataset = copy.deepcopy(full_dataset_camelyon17)\n", - "# full_dataset_camelyon17.split_dict\n", - "\n", - "# full_dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Initialize algorithm" + "config" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Train data...\n", - " hospital = 0: n = 53425\n", - " hospital = 1: n = 0\n", - " hospital = 2: n = 0\n", - " hospital = 3: n = 116959\n", - " hospital = 4: n = 132052\n", - "Validation (ID) data...\n", - " hospital = 0: n = 6011\n", - " hospital = 1: n = 0\n", - " hospital = 2: n = 0\n", - " hospital = 3: n = 12879\n", - " hospital = 4: n = 14670\n", - "Test data...\n", - " hospital = 0: n = 0\n", - " hospital = 1: n = 0\n", - " hospital = 2: n = 85054\n", - " hospital = 3: n = 0\n", - " hospital = 4: n = 0\n", - "Validation (OOD) data...\n", - " hospital = 0: n = 0\n", - " hospital = 1: n = 34904\n", - " hospital = 2: n = 0\n", - " hospital = 3: n = 0\n", - " hospital = 4: n = 0\n", - "Dout: 2\n" - ] - } - ], + "outputs": [], "source": [ - "config = config_camelyon\n", - "\n", + "# config = config_encode\n", "\n", "train_grouper = CombinatorialGrouper(\n", " dataset=full_dataset,\n", @@ -1488,6 +931,77 @@ " train_grouper=train_grouper)" ] }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Available objects for config:\n", + " AliasManager\n", + " DisplayFormatter\n", + " HistoryManager\n", + " IPCompleter\n", + " IPKernelApp\n", + " LoggingMagics\n", + " MagicsManager\n", + " OSMagics\n", + " PrefilterManager\n", + " ScriptMagics\n", + " StoreMagics\n", + " ZMQInteractiveShell\n" + ] + } + ], + "source": [ + "config" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "device(type='cuda', index=0)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "config.device" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'np' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfull_dataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_input\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;31m#\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name 'np' is not defined" + ] + } + ], + "source": [ + "np.array(full_dataset.get_input(0)).shape\n", + "#" + ] + }, { "cell_type": "code", "execution_count": 29, @@ -1545,7 +1059,12 @@ { "cell_type": "code", "execution_count": 30, - "metadata": {}, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, "outputs": [ { "data": { @@ -1678,33 +1197,22 @@ }, { "cell_type": "code", - "execution_count": 126, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 126, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "for b in full_dataset:\n", - " break" - ] + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] }, { "cell_type": "code", "execution_count": 33, - "metadata": { - "jupyter": { - "source_hidden": true - } - }, + "metadata": {}, "outputs": [], "source": [ "\n", diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 04f5d08d..b5657597 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -33,17 +33,17 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): self._version = '1.0' self._download_url = 'https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/' self._data_dir = self.initialize_data_dir(root_dir, download) - self._y_size = 1 - self._n_classes = 2 + self._y_size = 128 + # self._n_classes = 2 - self._train_chroms = ['chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] - # self._train_chroms = ['chr2', 'chr9', 'chr11'] + self._train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] + self._val_chroms = ['chr2', 'chr9', 'chr11'] self._test_chroms = ['chr1', 'chr8', 'chr21'] self._transcription_factor = 'MAX' self._train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] self._val_celltype = ['A549'] self._test_celltype = ['GM12878'] - self._all_chroms = self._train_chroms + self._test_chroms + self._all_chroms = self._train_chroms + self._val_chroms + self._test_chroms self._all_celltypes = self._train_celltypes + self._val_celltype + self._test_celltype self._metadata_map = {} @@ -91,60 +91,22 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): dnase_bw_path = os.path.join(self._data_dir, 'Leopard_dnase/{}.bigwig'.format(ct)) self._dnase_allcelltypes[ct] = pyBigWig.open(dnase_bw_path) - # Read in metadata dataframe from training+validation data - train_regions_labeled = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.train.labels.tsv.gz'.format(self._transcription_factor)), sep='\t') - val_regions_labeled = pd.read_csv(os.path.join(self._data_dir, 'labels/{}.val.labels.tsv.gz'.format(self._transcription_factor)), sep='\t') - training_df = train_regions_labeled[np.isin(train_regions_labeled['chr'], self._train_chroms)] - val_df = val_regions_labeled[np.isin(val_regions_labeled['chr'], self._test_chroms)] - all_df = pd.concat([training_df, val_df]) - - # Get the y values, and remove ambiguous labels by default. - pd_list = [] - for ct in self._all_celltypes: - tc_chr = all_df[['chr', 'start', 'stop', ct]] - tc_chr.columns = ['chr', 'start', 'stop', 'y'] - y_array = tc_chr['y'].replace({'U': 0, 'B': 1, 'A': -1}).values - - # Now filter out ambiguous labels - non_ambig_mask = (y_array != -1) - tc_chr['y'] = y_array - tc_chr = tc_chr[non_ambig_mask] - - tc_chr.insert(len(tc_chr.columns), 'celltype', ct) - pd_list.append(tc_chr) - print(time.time() - itime) - self._metadata_df = pd.concat(pd_list) - - # Downsample negatives to balance each celltype - samp_ndces = [] - itime = time.time() - neg_msk = (self._metadata_df['y'] == 0) - pos_msk = (self._metadata_df['y'] == 1) - for ct in _all_celltypes: - celltype_msk = (self._metadata_df['celltype'] == ct) - print(ct, time.time() - itime) - neg_ct_msk = np.logical_and(celltype_msk, neg_msk) - pos_ct_msk = np.logical_and(celltype_msk, pos_msk) - print(ct, time.time() - itime) - neg_ndces = np.where(neg_ct_msk)[0] - pos_ndces = np.where(pos_ct_msk)[0] - np.random.seed(42) - samp_neg_ndces = np.random.choice(neg_ndces, size=len(pos_ndces), replace=False) - samp_ndces.extend(samp_neg_ndces) - samp_ndces.extend(pos_ndces) - print(ct, time.time() - itime) - self._metadata_df = self._metadata_df.iloc[samp_ndces, :] + self._metadata_df = pd.read_csv( + self._data_dir + '/labels/MAX/metadata_df.bed', sep='\t', header=None, + index_col=None, names=['chr', 'start', 'stop', 'celltype'] + ) train_regions_mask = np.isin(self._metadata_df['chr'], self._train_chroms) - val_regions_mask = np.isin(self._metadata_df['chr'], self._test_chroms) + val_regions_mask = np.isin(self._metadata_df['chr'], self._val_chroms) + test_regions_mask = np.isin(self._metadata_df['chr'], self._test_chroms) train_celltype_mask = np.isin(self._metadata_df['celltype'], self._train_celltypes) val_celltype_mask = np.isin(self._metadata_df['celltype'], self._val_celltype) test_celltype_mask = np.isin(self._metadata_df['celltype'], self._test_celltype) split_array = -1*np.ones(self._metadata_df.shape[0]).astype(int) split_array[np.logical_and(train_regions_mask, train_celltype_mask)] = self._split_dict['train'] - split_array[np.logical_and(val_regions_mask, test_celltype_mask)] = self._split_dict['test'] - # Validate using test chr, either using a designated validation cell line ('val') or a training cell line ('id_val') + split_array[np.logical_and(test_regions_mask, test_celltype_mask)] = self._split_dict['test'] + # Validate using validation chr, either using a designated validation cell line ('val') or a training cell line ('id_val') split_array[np.logical_and(val_regions_mask, val_celltype_mask)] = self._split_dict['val'] split_array[np.logical_and(val_regions_mask, train_celltype_mask)] = self._split_dict['id_val'] @@ -153,19 +115,21 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): else: raise ValueError(f'Split scheme {self._split_scheme} not recognized') + metadata_mask = (self._metadata_df['split'] != -1) self._metadata_df = self._metadata_df[self._metadata_df['split'] != -1] - self._split_array = self._metadata_df['split'].values chr_ints = self._metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['chr'])] )).values celltype_ints = self._metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['celltype'])] )).values - self._y_array = torch.LongTensor(np.array(self._metadata_df['y'])) + self._split_array = self._metadata_df['split'].values + self._y_array = torch.Tensor(np.load(self._data_dir + '/labels/MAX/metadata_y.npy')) + self._y_array = self._y_array[metadata_mask] self._metadata_array = torch.stack( (torch.LongTensor(chr_ints), - torch.LongTensor(celltype_ints), - self._y_array), + torch.LongTensor(celltype_ints) + ), dim=1) - self._metadata_fields = ['chr', 'celltype', 'y'] + self._metadata_fields = ['chr', 'celltype'] self._eval_grouper = CombinatorialGrouper( dataset=self, @@ -174,7 +138,8 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): self._metric = Accuracy() super().__init__(root_dir, download, split_scheme) - + + """ def get_random_label_vec(metadata_df, output_size=128): # Sample a positively labeled region at random pos_mdf = metadata_df[metadata_df['y'] == 1] #.iloc[ metadata_df['chr'] == s['chr'], : ] @@ -188,33 +153,27 @@ def get_random_label_vec(metadata_df, output_size=128): # Get labels start_ndx = np.where(mdf['start'] == pos_seed_region['start'])[0][0] y_label_vec = mdf.iloc[start_ndx:start_ndx+output_size, :]['y'] + """ - def get_input(self, idx): + def get_input(self, idx, window_size=12800): """ Returns x for a given idx in metadata_array, which has been filtered to only take windows with the desired stride. Computes this from: (1) sequence features in self._seq_bp (2) DNase bigwig file handles in self._dnase_allcelltypes - (3) Metadata for the index (location along the genome with 200bp window width) + (3) Metadata for the index (location along the genome with 6400bp window width) + (4) Window_size, the length of sequence returned (centered on the 6400bp region in (3)) """ - this_metadata = self._metadata_df.iloc[idx, :] - """ - flank_size = 400 - interval_start = this_metadata['start'] - flank_size - interval_end = this_metadata['stop'] + flank_size - dnase_this = self._dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end] - seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end] - return torch.tensor(np.column_stack([seq_this, dnase_this])) - """ - window_size = 12800 - interval_start = this_metadata['start'] - interval_end = window_size + interval_start #this_metadata['stop'] + interval_start = this_metadata['start'] - int(window_size/4) + interval_end = interval_start + window_size #this_metadata['stop'] seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end] dnase_bw = self._dnase_allcelltypes[this_metadata['celltype']] dnase_this = dnase_bw.values(chrom, interval_start, interval_end, numpy=True) dnase_avg = self._dnase_allcelltypes['avg'].values(chrom, interval_start, interval_end, numpy=True) - return torch.tensor(np.column_stack([seq_this, dnase_this, dnase_avg])) + return torch.tensor(np.column_stack( + [np.nan_to_num(seq_this), np.nan_to_num(dnase_this), np.nan_to_num(dnase_avg)] + )) def eval(self, y_pred, y_true, metadata): return self.standard_group_eval( From 75d910fcf0ca5eb38774bc038b0cd6283afaf715 Mon Sep 17 00:00:00 2001 From: aikanor Date: Fri, 5 Mar 2021 10:26:05 -0800 Subject: [PATCH 074/244] final code (1/3) --- examples/models/CNN_genome.py | 4 +- examples/sbox_run_expt.ipynb | 671 +++++++++++++++++++++++---- wilds/datasets/encodetfbs_dataset.py | 4 +- 3 files changed, 579 insertions(+), 100 deletions(-) diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py index 147f8c9e..4f851706 100644 --- a/examples/models/CNN_genome.py +++ b/examples/models/CNN_genome.py @@ -26,7 +26,7 @@ def double_conv(in_channels, out_channels): class UNet(nn.Module): - def __init__(self, n_class, n_channels_in=6): + def __init__(self, out_features=16, n_channels_in=6): super().__init__() self.dconv_down1 = double_conv(n_channels_in, 15) @@ -46,7 +46,7 @@ def __init__(self, n_class, n_channels_in=6): self.dconv_up2 = double_conv(22 + 33, 22) self.dconv_up1 = double_conv(15 + 22, 15) - self.conv_last = nn.Conv2d(15, n_class, 1) + self.conv_last = nn.Conv1d(15, out_features, 1) def forward(self, x): diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb index 071a68d7..92cb1746 100644 --- a/examples/sbox_run_expt.ipynb +++ b/examples/sbox_run_expt.ipynb @@ -109,7 +109,7 @@ { "data": { "text/plain": [ - "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" + "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" ] }, "execution_count": 2, @@ -207,7 +207,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -219,20 +219,20 @@ "config_encode = parser.parse_args(argstr_encode.split())\n", "config_encode = populate_defaults(config_encode)\n", "\n", - "# config = config_camelyon\n", - "config = config_encode\n" + "config = config_camelyon\n", + "# config = config_encode\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Dataset: encode-tfbs\n", + "Dataset: camelyon17\n", "Algorithm: ERM\n", "Root dir: data\n", "Split scheme: official\n", @@ -244,26 +244,26 @@ "Uniform over groups: False\n", "Distinct groups: None\n", "N groups per batch: 2\n", - "Batch size: 64\n", + "Batch size: 32\n", "Eval loader: standard\n", - "Model: leopard\n", + "Model: densenet121\n", "Model kwargs: {'pretrained': False}\n", - "Train transform: None\n", - "Eval transform: None\n", - "Target resolution: None\n", + "Train transform: image_base\n", + "Eval transform: image_base\n", + "Target resolution: (224, 224)\n", "Resize scale: None\n", "Max token length: None\n", "Loss function: cross_entropy\n", - "Groupby fields: ['celltype']\n", + "Groupby fields: ['hospital']\n", "Group dro step size: None\n", - "Coral penalty weight: None\n", - "Irm lambda: None\n", + "Coral penalty weight: 0.1\n", + "Irm lambda: 1.0\n", "Irm penalty anneal iters: None\n", "Algo log metric: accuracy\n", "Val metric: acc_avg\n", "Val metric decreasing: False\n", "N epochs: 5\n", - "Optimizer: Adam\n", + "Optimizer: SGD\n", "Lr: 0.001\n", "Weight decay: 0.01\n", "Max grad norm: None\n", @@ -287,20 +287,7 @@ "Use wandb: False\n", "Progress bar: False\n", "Resume: False\n", - "\n", - "chr3 5.088634967803955\n", - "chr4 9.974164009094238\n", - "chr5 15.149149894714355\n", - "chr6 19.728455066680908\n", - "chr7 23.769655466079712\n", - "chr10 29.31521511077881\n", - "chr12 32.78225326538086\n", - "chr13 35.67028570175171\n", - "chr14 46.721638441085815\n", - "chr15 92.16564106941223\n", - "chr16 96.26218318939209\n", - "chr17 114.85105729103088\n", - "chr18 116.09504199028015\n" + "\n" ] } ], @@ -845,18 +832,62 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'algorithm' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mNameError\u001b[0m: name 'algorithm' is not defined" + ] + } + ], "source": [ - "config" + "algorithm.model" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train data...\n", + " hospital = 0: n = 53425\n", + " hospital = 1: n = 0\n", + " hospital = 2: n = 0\n", + " hospital = 3: n = 116959\n", + " hospital = 4: n = 132052\n", + "Validation (ID) data...\n", + " hospital = 0: n = 6011\n", + " hospital = 1: n = 0\n", + " hospital = 2: n = 0\n", + " hospital = 3: n = 12879\n", + " hospital = 4: n = 14670\n", + "Test data...\n", + " hospital = 0: n = 0\n", + " hospital = 1: n = 0\n", + " hospital = 2: n = 85054\n", + " hospital = 3: n = 0\n", + " hospital = 4: n = 0\n", + "Validation (OOD) data...\n", + " hospital = 0: n = 0\n", + " hospital = 1: n = 34904\n", + " hospital = 2: n = 0\n", + " hospital = 3: n = 0\n", + " hospital = 4: n = 0\n", + "Dout: 2\n" + ] + } + ], "source": [ "# config = config_encode\n", "\n", @@ -933,73 +964,521 @@ }, { "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Available objects for config:\n", - " AliasManager\n", - " DisplayFormatter\n", - " HistoryManager\n", - " IPCompleter\n", - " IPKernelApp\n", - " LoggingMagics\n", - " MagicsManager\n", - " OSMagics\n", - " PrefilterManager\n", - " ScriptMagics\n", - " StoreMagics\n", - " ZMQInteractiveShell\n" - ] - } - ], - "source": [ - "config" - ] - }, - { - "cell_type": "code", - "execution_count": 13, + "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "device(type='cuda', index=0)" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "config.device" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'np' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfull_dataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_input\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;31m#\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mNameError\u001b[0m: name 'np' is not defined" - ] - } - ], - "source": [ - "np.array(full_dataset.get_input(0)).shape\n", - "#" + "DenseNet(\n", + " (features): Sequential(\n", + " (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", + " (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu0): ReLU(inplace=True)\n", + " (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", + " (denseblock1): _DenseBlock(\n", + " (denselayer1): _DenseLayer(\n", + " (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer2): _DenseLayer(\n", + " (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer3): _DenseLayer(\n", + " (norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer4): _DenseLayer(\n", + " (norm1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer5): _DenseLayer(\n", + " (norm1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer6): _DenseLayer(\n", + " (norm1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " )\n", + " (transition1): _Transition(\n", + " (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", + " )\n", + " (denseblock2): _DenseBlock(\n", + " (denselayer1): _DenseLayer(\n", + " (norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer2): _DenseLayer(\n", + " (norm1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer3): _DenseLayer(\n", + " (norm1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer4): _DenseLayer(\n", + " (norm1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer5): _DenseLayer(\n", + " (norm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer6): _DenseLayer(\n", + " (norm1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer7): _DenseLayer(\n", + " (norm1): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer8): _DenseLayer(\n", + " (norm1): BatchNorm2d(352, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer9): _DenseLayer(\n", + " (norm1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer10): _DenseLayer(\n", + " (norm1): BatchNorm2d(416, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer11): _DenseLayer(\n", + " (norm1): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(448, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer12): _DenseLayer(\n", + " (norm1): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(480, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " )\n", + " (transition2): _Transition(\n", + " (norm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", + " )\n", + " (denseblock3): _DenseBlock(\n", + " (denselayer1): _DenseLayer(\n", + " (norm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer2): _DenseLayer(\n", + " (norm1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer3): _DenseLayer(\n", + " (norm1): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer4): _DenseLayer(\n", + " (norm1): BatchNorm2d(352, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer5): _DenseLayer(\n", + " (norm1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer6): _DenseLayer(\n", + " (norm1): BatchNorm2d(416, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer7): _DenseLayer(\n", + " (norm1): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(448, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer8): _DenseLayer(\n", + " (norm1): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(480, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer9): _DenseLayer(\n", + " (norm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer10): _DenseLayer(\n", + " (norm1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer11): _DenseLayer(\n", + " (norm1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer12): _DenseLayer(\n", + " (norm1): BatchNorm2d(608, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer13): _DenseLayer(\n", + " (norm1): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer14): _DenseLayer(\n", + " (norm1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer15): _DenseLayer(\n", + " (norm1): BatchNorm2d(704, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(704, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer16): _DenseLayer(\n", + " (norm1): BatchNorm2d(736, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(736, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer17): _DenseLayer(\n", + " (norm1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer18): _DenseLayer(\n", + " (norm1): BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(800, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer19): _DenseLayer(\n", + " (norm1): BatchNorm2d(832, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer20): _DenseLayer(\n", + " (norm1): BatchNorm2d(864, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(864, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer21): _DenseLayer(\n", + " (norm1): BatchNorm2d(896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer22): _DenseLayer(\n", + " (norm1): BatchNorm2d(928, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(928, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer23): _DenseLayer(\n", + " (norm1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(960, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer24): _DenseLayer(\n", + " (norm1): BatchNorm2d(992, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(992, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " )\n", + " (transition3): _Transition(\n", + " (norm): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu): ReLU(inplace=True)\n", + " (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", + " )\n", + " (denseblock4): _DenseBlock(\n", + " (denselayer1): _DenseLayer(\n", + " (norm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer2): _DenseLayer(\n", + " (norm1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer3): _DenseLayer(\n", + " (norm1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer4): _DenseLayer(\n", + " (norm1): BatchNorm2d(608, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer5): _DenseLayer(\n", + " (norm1): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer6): _DenseLayer(\n", + " (norm1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer7): _DenseLayer(\n", + " (norm1): BatchNorm2d(704, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(704, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer8): _DenseLayer(\n", + " (norm1): BatchNorm2d(736, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(736, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer9): _DenseLayer(\n", + " (norm1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer10): _DenseLayer(\n", + " (norm1): BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(800, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer11): _DenseLayer(\n", + " (norm1): BatchNorm2d(832, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer12): _DenseLayer(\n", + " (norm1): BatchNorm2d(864, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(864, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer13): _DenseLayer(\n", + " (norm1): BatchNorm2d(896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer14): _DenseLayer(\n", + " (norm1): BatchNorm2d(928, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(928, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer15): _DenseLayer(\n", + " (norm1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(960, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " (denselayer16): _DenseLayer(\n", + " (norm1): BatchNorm2d(992, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu1): ReLU(inplace=True)\n", + " (conv1): Conv2d(992, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (relu2): ReLU(inplace=True)\n", + " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " )\n", + " )\n", + " (norm5): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (classifier): Linear(in_features=1024, out_features=2, bias=True)\n", + ")" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "algorithm.model" ] }, { diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index b5657597..23b70014 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -5,7 +5,7 @@ import pyBigWig from wilds.datasets.wilds_dataset import WILDSDataset from wilds.common.grouper import CombinatorialGrouper -from wilds.common.metrics.all_metrics import Accuracy +from wilds.common.metrics.all_metrics import Accuracy, MultiTaskAccuracy all_chrom_names = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] @@ -135,7 +135,7 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): dataset=self, groupby_fields=['celltype']) - self._metric = Accuracy() + self._metric = MultiTaskAccuracy() super().__init__(root_dir, download, split_scheme) From 88c0d51e82e0eb85f19734579437b4888735a53c Mon Sep 17 00:00:00 2001 From: aikanor Date: Thu, 11 Mar 2021 07:57:37 -0800 Subject: [PATCH 075/244] final integration w/eval code 2/3 --- examples/configs/data_loader.py | 2 +- examples/configs/datasets.py | 7 +- examples/configs/model.py | 4 +- examples/models/CNN_genome.py | 24 +- examples/sbox_run_expt.ipynb | 1960 +++++--------------------- wilds/common/metrics/all_metrics.py | 30 +- wilds/common/metrics/metric.py | 1 + wilds/datasets/encodetfbs_dataset.py | 15 +- 8 files changed, 433 insertions(+), 1610 deletions(-) diff --git a/examples/configs/data_loader.py b/examples/configs/data_loader.py index 38741464..c00b1b64 100644 --- a/examples/configs/data_loader.py +++ b/examples/configs/data_loader.py @@ -1,6 +1,6 @@ loader_defaults = { 'loader_kwargs':{ - 'num_workers': 4, + 'num_workers': 1, 'pin_memory': True, }, 'n_groups_per_batch': 4, diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 7731dbd0..4ea7cff2 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -111,19 +111,18 @@ 'model_kwargs': {'pretrained': False}, 'train_transform': None, 'eval_transform': None, - 'loss_function': 'cross_entropy', + 'loss_function': 'multitask_bce', 'groupby_fields': ['celltype'], 'val_metric': 'acc_avg', 'val_metric_decreasing': False, - 'optimizer': 'Adam', - # 'optimizer_kwargs': { }, + 'optimizer': 'Adam', 'scheduler': None, 'batch_size': 64, 'lr': 0.001, 'weight_decay': 0.01, 'n_epochs': 5, 'n_groups_per_batch': 2, - 'algo_log_metric': 'accuracy', + 'algo_log_metric': 'multitask_avgprec', # 'irm_lambda': 1.0, # 'coral_penalty_weight': 0.1, }, diff --git a/examples/configs/model.py b/examples/configs/model.py index f4eb8779..a4df713b 100644 --- a/examples/configs/model.py +++ b/examples/configs/model.py @@ -37,5 +37,7 @@ 'target_resolution': (224, 224), }, 'logistic_regression': {}, - 'leopard': {}, + 'leopard': { + 'optimizer': 'Adam' + }, } diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py index 4f851706..7397eeb2 100644 --- a/examples/models/CNN_genome.py +++ b/examples/models/CNN_genome.py @@ -37,16 +37,23 @@ def __init__(self, out_features=16, n_channels_in=6): self.dconv_down6 = double_conv(73, 109) self.maxpool = nn.MaxPool1d(2) + # self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) self.conv_middle = single_conv(109, 109) + self.upsamp_6 = nn.ConvTranspose1d(109, 109, 2, stride=2) self.dconv_up5 = double_conv(73 + 109, 73) + self.upsamp_5 = nn.ConvTranspose1d(73, 73, 2, stride=2) self.dconv_up4 = double_conv(49 + 73, 49) + self.upsamp_4 = nn.ConvTranspose1d(49, 49, 2, stride=2) self.dconv_up3 = double_conv(33 + 49, 33) + self.upsamp_3 = nn.ConvTranspose1d(33, 33, 2, stride=2) self.dconv_up2 = double_conv(22 + 33, 22) + self.upsamp_2 = nn.ConvTranspose1d(22, 22, 2, stride=2) self.dconv_up1 = double_conv(15 + 22, 15) + self.upsamp_1 = nn.ConvTranspose1d(15, 15, 2, stride=2) - self.conv_last = nn.Conv1d(15, out_features, 1) + self.conv_last = nn.Conv1d(15, 1, 200, stride=50, padding=0) def forward(self, x): @@ -72,27 +79,28 @@ def forward(self, x): # Encoder finished. - x = self.upsample(conv6) # (input_size / 16) x 109 + x = self.upsamp_6(conv6) # (input_size / 16) x 109 x = torch.cat([x, conv5], dim=1) # (input_size / 16) x (109 + 73) x = self.dconv_up5(x) # (input_size / 16) x 73 - x = self.upsample(x) # (input_size / 8) x 73 + x = self.upsamp_5(x) # (input_size / 8) x 73 x = torch.cat([x, conv4], dim=1) # (input_size / 8) x (73 + 49) x = self.dconv_up4(x) # (input_size / 8) x 49 - x = self.upsample(x) # (input_size / 4) x 49 + x = self.upsamp_4(x) # (input_size / 4) x 49 x = torch.cat([x, conv3], dim=1) # (input_size / 4) x (49 + 33) x = self.dconv_up3(x) # (input_size / 4) x 33 - x = self.upsample(x) # (input_size / 2) x 33 + x = self.upsamp_3(x) # (input_size / 2) x 33 x = torch.cat([x, conv2], dim=1) # (input_size / 2) x (33 + 22) x = self.dconv_up2(x) # (input_size / 2) x 22 - x = self.upsample(x) # (input_size) x 22 + x = self.upsamp_2(x) # (input_size) x 22 x = torch.cat([x, conv1], dim=1) # (input_size) x (22 + 15) x = self.dconv_up1(x) # (input_size) x 15 - out = self.conv_last(x) + # middle 128 bits + out = self.conv_last(x)[:, :, 64:192] - return out + return torch.squeeze(out) diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb index 92cb1746..5040aeb0 100644 --- a/examples/sbox_run_expt.ipynb +++ b/examples/sbox_run_expt.ipynb @@ -11,18 +11,14 @@ }, { "cell_type": "code", - "execution_count": 123, + "execution_count": 27, "metadata": {}, "outputs": [ { - "ename": "ModuleNotFoundError", - "evalue": "No module named 'psutil'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpsutil\u001b[0m\u001b[0;34m;\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpsutil\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mProcess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgetpid\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmemory_info\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrss\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0;36m1024\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'psutil'" + "name": "stdout", + "output_type": "stream", + "text": [ + "4860.4765625\n" ] } ], @@ -109,7 +105,7 @@ { "data": { "text/plain": [ - "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" + "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" ] }, "execution_count": 2, @@ -120,7 +116,7 @@ "source": [ "''' set default hyperparams in default_hyperparams.py '''\n", "parser = argparse.ArgumentParser()\n", - "CombinatorialGrouper\n", + "\n", "# Required arguments\n", "parser.add_argument('-d', '--dataset', choices=supported.datasets, required=True)\n", "parser.add_argument('--algorithm', required=True, choices=supported.algorithms)\n", @@ -207,11 +203,12 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "argstr_camelyon = \"--dataset camelyon17 --algorithm ERM --root_dir data\"\n", + "# argstr_camelyon = \"--dataset civilcomments --algorithm ERM --root_dir data\"\n", "config_camelyon = parser.parse_args(argstr_camelyon.split())\n", "config_camelyon = populate_defaults(config_camelyon)\n", "\n", @@ -220,7 +217,31 @@ "config_encode = populate_defaults(config_encode)\n", "\n", "config = config_camelyon\n", - "# config = config_encode\n" + "config = config_encode\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "argstr_camelyon = \"--dataset camelyon17 --algorithm ERM --root_dir data\"\n", + "# argstr_camelyon = \"--dataset civilcomments --algorithm ERM --root_dir data\"\n", + "config_camelyon = parser.parse_args(argstr_camelyon.split())\n", + "\n", + "argstr_encode = \"--dataset encode-tfbs --algorithm ERM --root_dir data\"\n", + "config_encode = parser.parse_args(argstr_encode.split())\n", + "config_encode" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "config.optimizer_kwargs = {}" ] }, { @@ -232,42 +253,42 @@ "name": "stdout", "output_type": "stream", "text": [ - "Dataset: camelyon17\n", + "Dataset: encode-tfbs\n", "Algorithm: ERM\n", "Root dir: data\n", "Split scheme: official\n", "Dataset kwargs: {}\n", "Download: False\n", "Frac: 1.0\n", - "Loader kwargs: {'num_workers': 4, 'pin_memory': True}\n", + "Loader kwargs: {'num_workers': 1, 'pin_memory': True}\n", "Train loader: standard\n", "Uniform over groups: False\n", "Distinct groups: None\n", "N groups per batch: 2\n", - "Batch size: 32\n", + "Batch size: 64\n", "Eval loader: standard\n", - "Model: densenet121\n", + "Model: leopard\n", "Model kwargs: {'pretrained': False}\n", - "Train transform: image_base\n", - "Eval transform: image_base\n", - "Target resolution: (224, 224)\n", + "Train transform: None\n", + "Eval transform: None\n", + "Target resolution: None\n", "Resize scale: None\n", "Max token length: None\n", - "Loss function: cross_entropy\n", - "Groupby fields: ['hospital']\n", + "Loss function: multitask_bce\n", + "Groupby fields: ['celltype']\n", "Group dro step size: None\n", - "Coral penalty weight: 0.1\n", - "Irm lambda: 1.0\n", + "Coral penalty weight: None\n", + "Irm lambda: None\n", "Irm penalty anneal iters: None\n", - "Algo log metric: accuracy\n", + "Algo log metric: multitask_avgprec\n", "Val metric: acc_avg\n", "Val metric decreasing: False\n", "N epochs: 5\n", - "Optimizer: SGD\n", + "Optimizer: Adam\n", "Lr: 0.001\n", "Weight decay: 0.01\n", "Max grad norm: None\n", - "Optimizer kwargs: {'momentum': 0.9}\n", + "Optimizer kwargs: {}\n", "Scheduler: None\n", "Scheduler kwargs: {}\n", "Scheduler metric split: val\n", @@ -287,7 +308,10 @@ "Use wandb: False\n", "Progress bar: False\n", "Resume: False\n", - "\n" + "\n", + "chr3 2.9614717960357666\n", + "chr2 6.587897777557373\n", + "chr1 10.29332971572876\n" ] } ], @@ -337,13 +361,20 @@ }, { "cell_type": "code", - "execution_count": 5, - "metadata": { - "jupyter": { - "source_hidden": true + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" } - }, - "outputs": [], + ], "source": [ "import copy\n", "full_dataset_camelyon17 = copy.deepcopy(full_dataset)\n", @@ -361,12 +392,11 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 6, "metadata": { "collapsed": true, "jupyter": { - "outputs_hidden": true, - "source_hidden": true + "outputs_hidden": true } }, "outputs": [ @@ -374,36 +404,36 @@ "name": "stdout", "output_type": "stream", "text": [ - "chr3 3.0872416496276855\n", - "chr4 6.014077425003052\n", - "chr5 8.789116859436035\n", - "chr6 11.409600496292114\n", - "chr7 13.844907283782959\n", - "chr10 15.919893264770508\n", - "chr12 17.969276189804077\n", - "chr13 19.71941637992859\n", - "chr14 21.34366464614868\n", - "chr15 22.900768995285034\n", - "chr16 24.27766728401184\n", - "chr17 25.519333600997925\n", - "chr18 26.714667797088623\n", - "chr19 27.614336490631104\n", - "chr20 28.57899522781372\n", - "chr22 29.353068113327026\n", - "chrX 31.731130599975586\n", - "chr2 35.449124813079834\n", - "chr9 37.5920934677124\n", - "chr11 39.65406608581543\n", - "chr1 43.44736051559448\n", - "chr8 45.68234419822693\n", - "chr21 46.41120982170105\n", - "H1-hESC 46.41424226760864\n", - "HCT116 46.41492676734924\n", - "HeLa-S3 46.41563010215759\n", - "HepG2 46.41687893867493\n", - "K562 46.41777992248535\n", - "A549 46.41860294342041\n", - "GM12878 46.41955780982971\n" + "chr3 3.0055365562438965\n", + "chr4 5.905960321426392\n", + "chr5 8.651455879211426\n", + "chr6 11.250766038894653\n", + "chr7 13.660939931869507\n", + "chr10 15.713522672653198\n", + "chr12 17.740623474121094\n", + "chr13 19.478207111358643\n", + "chr14 21.088634252548218\n", + "chr15 22.625713348388672\n", + "chr16 23.987269639968872\n", + "chr17 25.21428894996643\n", + "chr18 26.394341230392456\n", + "chr19 27.28497076034546\n", + "chr20 28.235496282577515\n", + "chr22 28.999913692474365\n", + "chrX 31.338406085968018\n", + "chr2 35.00527381896973\n", + "chr9 37.12277841567993\n", + "chr11 39.157737016677856\n", + "chr1 42.89226841926575\n", + "chr8 45.092690229415894\n", + "chr21 45.81230306625366\n", + "H1-hESC 45.81402635574341\n", + "HCT116 45.814292192459106\n", + "HeLa-S3 45.814526081085205\n", + "HepG2 45.814810276031494\n", + "K562 45.815062522888184\n", + "A549 45.81636619567871\n", + "GM12878 45.81674289703369\n" ] } ], @@ -492,52 +522,8 @@ }, { "cell_type": "code", - "execution_count": 325, - "metadata": { - "jupyter": { - "source_hidden": true - } - }, - "outputs": [], - "source": [ - "def get_random_label_vec(\n", - " metadata_df, seed_chr, seed_celltype, seed_start, output_size=128\n", - "):\n", - " \"\"\"\n", - " Given a coordinate in a celltype, gets the labels of \n", - " the `output_size` 200bp bins from that coordinate onward. \n", - " \"\"\"\n", - " itime = time.time()\n", - " \n", - " # Extract regions from this chromosome in this celltype, to get a window of labels from\n", - " # print(time.time() - itime)\n", - " # chr_msk = np.array(metadata_df['chr']) == seed_region['chr']\n", - " # print(time.time() - itime)\n", - " # ct_msk = np.array(metadata_df['celltype']) == seed_region['celltype']\n", - " # mdf = metadata_df[chr_msk & ct_msk]\n", - " seq_size = output_size*50\n", - " mdf = metadata_df.loc[\n", - " (metadata_df['chr'] == seed_chr) & \n", - " (metadata_df['celltype'] == seed_celltype) & \n", - " (metadata_df['start'] >= seed_start) & \n", - " (metadata_df['stop'] < seed_start+seq_size)\n", - " ]\n", - " print(time.time() - itime)\n", - "\n", - " # Get labels\n", - " y_label_vec = np.zeros(output_size)\n", - " y_label_vec[(mdf['start'] - seed_start) // 50] = mdf['y']\n", - " return mdf, y_label_vec" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "jupyter": { - "source_hidden": true - } - }, + "execution_count": 7, + "metadata": {}, "outputs": [], "source": [ "train_regions_mask = np.isin(_metadata_df['chr'], _train_chroms)\n", @@ -577,6 +563,42 @@ "_metadata_fields = ['chr', 'celltype']" ] }, + { + "cell_type": "code", + "execution_count": 325, + "metadata": {}, + "outputs": [], + "source": [ + "def get_random_label_vec(\n", + " metadata_df, seed_chr, seed_celltype, seed_start, output_size=128\n", + "):\n", + " \"\"\"\n", + " Given a coordinate in a celltype, gets the labels of \n", + " the `output_size` 200bp bins from that coordinate onward. \n", + " \"\"\"\n", + " itime = time.time()\n", + " \n", + " # Extract regions from this chromosome in this celltype, to get a window of labels from\n", + " # print(time.time() - itime)\n", + " # chr_msk = np.array(metadata_df['chr']) == seed_region['chr']\n", + " # print(time.time() - itime)\n", + " # ct_msk = np.array(metadata_df['celltype']) == seed_region['celltype']\n", + " # mdf = metadata_df[chr_msk & ct_msk]\n", + " seq_size = output_size*50\n", + " mdf = metadata_df.loc[\n", + " (metadata_df['chr'] == seed_chr) & \n", + " (metadata_df['celltype'] == seed_celltype) & \n", + " (metadata_df['start'] >= seed_start) & \n", + " (metadata_df['stop'] < seed_start+seq_size)\n", + " ]\n", + " print(time.time() - itime)\n", + "\n", + " # Get labels\n", + " y_label_vec = np.zeros(output_size)\n", + " y_label_vec[(mdf['start'] - seed_start) // 50] = mdf['y']\n", + " return mdf, y_label_vec" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -587,11 +609,7 @@ { "cell_type": "code", "execution_count": 24, - "metadata": { - "jupyter": { - "source_hidden": true - } - }, + "metadata": {}, "outputs": [], "source": [ "import os, time\n", @@ -830,27 +848,6 @@ "# Initialize algorithm" ] }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'algorithm' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mNameError\u001b[0m: name 'algorithm' is not defined" - ] - } - ], - "source": [ - "algorithm.model" - ] - }, { "cell_type": "code", "execution_count": 6, @@ -861,30 +858,38 @@ "output_type": "stream", "text": [ "Train data...\n", - " hospital = 0: n = 53425\n", - " hospital = 1: n = 0\n", - " hospital = 2: n = 0\n", - " hospital = 3: n = 116959\n", - " hospital = 4: n = 132052\n", + " celltype = H1-hESC: n = 5314\n", + " celltype = HCT116: n = 4759\n", + " celltype = HeLa-S3: n = 4635\n", + " celltype = HepG2: n = 4459\n", + " celltype = K562: n = 5169\n", + " celltype = A549: n = 0\n", + " celltype = GM12878: n = 0\n", "Validation (ID) data...\n", - " hospital = 0: n = 6011\n", - " hospital = 1: n = 0\n", - " hospital = 2: n = 0\n", - " hospital = 3: n = 12879\n", - " hospital = 4: n = 14670\n", + " celltype = H1-hESC: n = 6872\n", + " celltype = HCT116: n = 6315\n", + " celltype = HeLa-S3: n = 4219\n", + " celltype = HepG2: n = 8356\n", + " celltype = K562: n = 6538\n", + " celltype = A549: n = 0\n", + " celltype = GM12878: n = 0\n", "Test data...\n", - " hospital = 0: n = 0\n", - " hospital = 1: n = 0\n", - " hospital = 2: n = 85054\n", - " hospital = 3: n = 0\n", - " hospital = 4: n = 0\n", + " celltype = H1-hESC: n = 0\n", + " celltype = HCT116: n = 0\n", + " celltype = HeLa-S3: n = 0\n", + " celltype = HepG2: n = 0\n", + " celltype = K562: n = 0\n", + " celltype = A549: n = 0\n", + " celltype = GM12878: n = 4487\n", "Validation (OOD) data...\n", - " hospital = 0: n = 0\n", - " hospital = 1: n = 34904\n", - " hospital = 2: n = 0\n", - " hospital = 3: n = 0\n", - " hospital = 4: n = 0\n", - "Dout: 2\n" + " celltype = H1-hESC: n = 0\n", + " celltype = HCT116: n = 0\n", + " celltype = HeLa-S3: n = 0\n", + " celltype = HepG2: n = 0\n", + " celltype = K562: n = 0\n", + " celltype = A549: n = 6728\n", + " celltype = GM12878: n = 0\n", + "Dout: 128\n" ] } ], @@ -966,630 +971,237 @@ "cell_type": "code", "execution_count": 7, "metadata": {}, + "outputs": [], + "source": [ + "for batch in datasets['train']['loader']:\n", + " x, y_true, metadata = batch\n", + " break\n", + "# x = torch.transpose(x, 1, 2)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "d = algorithm.process_batch(batch)\n", + "# algorithm.loss.compute" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "DenseNet(\n", - " (features): Sequential(\n", - " (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", - " (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu0): ReLU(inplace=True)\n", - " (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", - " (denseblock1): _DenseBlock(\n", - " (denselayer1): _DenseLayer(\n", - " (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer2): _DenseLayer(\n", - " (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer3): _DenseLayer(\n", - " (norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer4): _DenseLayer(\n", - " (norm1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer5): _DenseLayer(\n", - " (norm1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer6): _DenseLayer(\n", - " (norm1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " )\n", - " (transition1): _Transition(\n", - " (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu): ReLU(inplace=True)\n", - " (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", - " )\n", - " (denseblock2): _DenseBlock(\n", - " (denselayer1): _DenseLayer(\n", - " (norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer2): _DenseLayer(\n", - " (norm1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer3): _DenseLayer(\n", - " (norm1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer4): _DenseLayer(\n", - " (norm1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer5): _DenseLayer(\n", - " (norm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer6): _DenseLayer(\n", - " (norm1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer7): _DenseLayer(\n", - " (norm1): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer8): _DenseLayer(\n", - " (norm1): BatchNorm2d(352, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer9): _DenseLayer(\n", - " (norm1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer10): _DenseLayer(\n", - " (norm1): BatchNorm2d(416, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer11): _DenseLayer(\n", - " (norm1): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(448, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer12): _DenseLayer(\n", - " (norm1): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(480, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " )\n", - " (transition2): _Transition(\n", - " (norm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu): ReLU(inplace=True)\n", - " (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", - " )\n", - " (denseblock3): _DenseBlock(\n", - " (denselayer1): _DenseLayer(\n", - " (norm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer2): _DenseLayer(\n", - " (norm1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer3): _DenseLayer(\n", - " (norm1): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer4): _DenseLayer(\n", - " (norm1): BatchNorm2d(352, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer5): _DenseLayer(\n", - " (norm1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer6): _DenseLayer(\n", - " (norm1): BatchNorm2d(416, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer7): _DenseLayer(\n", - " (norm1): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(448, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer8): _DenseLayer(\n", - " (norm1): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(480, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer9): _DenseLayer(\n", - " (norm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer10): _DenseLayer(\n", - " (norm1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer11): _DenseLayer(\n", - " (norm1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer12): _DenseLayer(\n", - " (norm1): BatchNorm2d(608, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer13): _DenseLayer(\n", - " (norm1): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer14): _DenseLayer(\n", - " (norm1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer15): _DenseLayer(\n", - " (norm1): BatchNorm2d(704, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(704, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer16): _DenseLayer(\n", - " (norm1): BatchNorm2d(736, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(736, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer17): _DenseLayer(\n", - " (norm1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer18): _DenseLayer(\n", - " (norm1): BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(800, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer19): _DenseLayer(\n", - " (norm1): BatchNorm2d(832, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer20): _DenseLayer(\n", - " (norm1): BatchNorm2d(864, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(864, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer21): _DenseLayer(\n", - " (norm1): BatchNorm2d(896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer22): _DenseLayer(\n", - " (norm1): BatchNorm2d(928, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(928, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer23): _DenseLayer(\n", - " (norm1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(960, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer24): _DenseLayer(\n", - " (norm1): BatchNorm2d(992, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(992, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " )\n", - " (transition3): _Transition(\n", - " (norm): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu): ReLU(inplace=True)\n", - " (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", - " )\n", - " (denseblock4): _DenseBlock(\n", - " (denselayer1): _DenseLayer(\n", - " (norm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer2): _DenseLayer(\n", - " (norm1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer3): _DenseLayer(\n", - " (norm1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer4): _DenseLayer(\n", - " (norm1): BatchNorm2d(608, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer5): _DenseLayer(\n", - " (norm1): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer6): _DenseLayer(\n", - " (norm1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer7): _DenseLayer(\n", - " (norm1): BatchNorm2d(704, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(704, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer8): _DenseLayer(\n", - " (norm1): BatchNorm2d(736, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(736, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer9): _DenseLayer(\n", - " (norm1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer10): _DenseLayer(\n", - " (norm1): BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(800, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer11): _DenseLayer(\n", - " (norm1): BatchNorm2d(832, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer12): _DenseLayer(\n", - " (norm1): BatchNorm2d(864, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(864, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer13): _DenseLayer(\n", - " (norm1): BatchNorm2d(896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer14): _DenseLayer(\n", - " (norm1): BatchNorm2d(928, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(928, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer15): _DenseLayer(\n", - " (norm1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(960, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer16): _DenseLayer(\n", - " (norm1): BatchNorm2d(992, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(992, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " )\n", - " (norm5): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " (classifier): Linear(in_features=1024, out_features=2, bias=True)\n", - ")" + "tensor(0.7212, device='cuda:0', grad_fn=)" ] }, - "execution_count": 7, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "algorithm.model" + "a = algorithm.loss.compute(d['y_pred'], d['y_true'], return_dict=False)\n", + "a" ] }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 10, "metadata": {}, "outputs": [ { "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
chrstartstopcelltypesplit
39413chr21000320010009600A5493
39414chr2100032000100038400A5493
39415chr2100102400100108800A5493
39416chr2100172800100179200A5493
39417chr2100230400100236800A5493
..................
495287chr3999680010003200K5620
495288chr39997440099980800K5620
495289chr39998080099987200K5620
495290chr39998720099993600K5620
495291chr399993600100000000K5620
\n", + "

67851 rows × 5 columns

\n", + "
" + ], "text/plain": [ - "" + " chr start stop celltype split\n", + "39413 chr2 10003200 10009600 A549 3\n", + "39414 chr2 100032000 100038400 A549 3\n", + "39415 chr2 100102400 100108800 A549 3\n", + "39416 chr2 100172800 100179200 A549 3\n", + "39417 chr2 100230400 100236800 A549 3\n", + "... ... ... ... ... ...\n", + "495287 chr3 9996800 10003200 K562 0\n", + "495288 chr3 99974400 99980800 K562 0\n", + "495289 chr3 99980800 99987200 K562 0\n", + "495290 chr3 99987200 99993600 K562 0\n", + "495291 chr3 99993600 100000000 K562 0\n", + "\n", + "[67851 rows x 5 columns]" ] }, - "execution_count": 29, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "# algorithm.device\n", - "full_dataset\n", - "# datasets['train']['loader']" + "#np.unique(full_dataset._metadata_df['split'], return_counts=True)\n", + "full_dataset._metadata_df" ] }, { "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "for batch in datasets['train']['loader']:\n", - " x, y_true, metadata = batch\n", - " break" - ] - }, - { - "cell_type": "code", - "execution_count": 43, + "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,\n", - " 0, 1, 1, 1, 0, 0, 0, 0])" + "(array([0. , 0.5, 1. ], dtype=float32), array([7422683, 1007200, 255045]))" ] }, - "execution_count": 43, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "y_true" + "np.unique(full_dataset.y_array, return_counts=True)" ] }, { "cell_type": "code", - "execution_count": 30, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true - } - }, + "execution_count": 26, + "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor([[ 0.1406, -0.0628],\n", - " [ 0.0534, 0.0359],\n", - " [-0.0174, -0.0097],\n", - " [-0.0571, -0.2381],\n", - " [ 0.1590, -0.0559],\n", - " [ 0.1254, -0.0139],\n", - " [-0.0423, 0.0439],\n", - " [ 0.1621, 0.0730],\n", - " [ 0.0554, 0.0796],\n", - " [-0.0532, 0.0667],\n", - " [-0.1927, -0.0387],\n", - " [ 0.1352, -0.0385],\n", - " [-0.1320, 0.0140],\n", - " [-0.0531, -0.1171],\n", - " [-0.0378, -0.0134],\n", - " [ 0.1047, 0.0298],\n", - " [ 0.0355, -0.0497],\n", - " [ 0.1065, -0.0218],\n", - " [-0.1883, 0.1298],\n", - " [ 0.0699, -0.0875],\n", - " [-0.1233, 0.1793],\n", - " [ 0.0151, 0.0708],\n", - " [-0.0973, -0.0033],\n", - " [ 0.1027, -0.2456],\n", - " [ 0.0433, -0.0441],\n", - " [ 0.1013, -0.1020],\n", - " [ 0.1309, -0.0051],\n", - " [ 0.0028, -0.0558],\n", - " [ 0.0635, 0.0575],\n", - " [-0.0066, 0.0666],\n", - " [-0.0076, -0.0375],\n", - " [ 0.1336, 0.0024]], device='cuda:0', grad_fn=)" + "0.8546625832706961" ] }, - "execution_count": 30, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "# datasets['train']['dataset'].size()\n", - "algorithm.model(x.to(algorithm.device))" + "7422683/8684928" ] }, { @@ -1601,9 +1213,40 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Epoch [0]:\n", + "\n", + "Train:\n", + "torch.Size([8192]) torch.Size([8192]) torch.Size([64, 128]) torch.Size([64, 128])\n", + "torch.Size([]) torch.Size([8192]) torch.Size([64, 128]) torch.Size([64, 128])\n" + ] + }, + { + "ename": "AssertionError", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m train(\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdatasets\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/examples/train.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(algorithm, datasets, general_logger, config, epoch_offset, best_val_metric)\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0;31m# First run training\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 90\u001b[0;31m \u001b[0mrun_epoch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgeneral_logger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 91\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;31m# Then run val\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/examples/train.py\u001b[0m in \u001b[0;36mrun_epoch\u001b[0;34m(algorithm, dataset, general_logger, epoch, config, train)\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miterator\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 43\u001b[0;31m \u001b[0mbatch_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 44\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0mbatch_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/examples/algorithms/single_model_algorithm.py\u001b[0m in \u001b[0;36mupdate\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 103\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_update\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[0;31m# log results\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 105\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_log\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 106\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msanitize_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/examples/algorithms/group_algorithm.py\u001b[0m in \u001b[0;36mupdate_log\u001b[0;34m(self, results)\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mm\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogged_metrics\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_group_logging\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m group_metrics, group_counts, worst_group_metric = m.compute_group_wise(\n\u001b[0m\u001b[1;32m 50\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'y_pred'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'y_true'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/wilds/common/metrics/metric.py\u001b[0m in \u001b[0;36mcompute_group_wise\u001b[0;34m(self, y_pred, y_true, g, n_groups, return_dict)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mDictionary\u001b[0m \u001b[0mof\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \"\"\"\n\u001b[0;32m--> 114\u001b[0;31m \u001b[0mgroup_metrics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgroup_counts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mworst_group_metric\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compute_group_wise\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_groups\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 115\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/wilds/common/metrics/metric.py\u001b[0m in \u001b[0;36m_compute_group_wise\u001b[0;34m(self, y_pred, y_true, g, n_groups)\u001b[0m\n\u001b[1;32m 234\u001b[0m \u001b[0mflattened_g\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mindices\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 235\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mflattened_metrics\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mflattened_g\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 236\u001b[0;31m \u001b[0mgroup_metrics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgroup_counts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mavg_over_groups\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mflattened_metrics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mflattened_g\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_groups\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 237\u001b[0m \u001b[0mworst_group_metric\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mworst\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgroup_metrics\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mgroup_counts\u001b[0m\u001b[0;34m>\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 238\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mgroup_metrics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgroup_counts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mworst_group_metric\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/wilds/common/utils.py\u001b[0m in \u001b[0;36mavg_over_groups\u001b[0;34m(v, g, n_groups)\u001b[0m\n\u001b[1;32m 84\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[0mdevice\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 86\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 87\u001b[0m \u001b[0mgroup_count\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_counts\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_groups\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[0mgroup_avgs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch_scatter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscatter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msrc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindex\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mn_groups\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'mean'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mAssertionError\u001b[0m: " + ] + } + ], "source": [ "if not config.eval_only:\n", " ## Load saved results if resuming\n", @@ -1681,867 +1324,6 @@ "outputs": [], "source": [] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 33, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "\n", - "class Beagle(nn.Module):\n", - " \"\"\"\n", - " Neural net models over genomic sequence.\n", - " Input:\n", - " - sequence_length: int (default 1000) \n", - " - Shape: (N, 5, sequence_length, 1) with batch size N.\n", - " \n", - " Output:\n", - " - prediction (Tensor): float torch tensor of shape (N, )\n", - " \n", - " TODO: Finish docstring.\n", - " \"\"\"\n", - " def __init__(self):\n", - " \"\"\"\n", - " Parameters\n", - " ----------\n", - " sequence_length : int\n", - " n_genomic_features : int\n", - " \"\"\"\n", - " super(Beagle, self).__init__()\n", - "\n", - " self.dropout = 0.3\n", - " self.num_cell_types = 1\n", - " self.conv1 = nn.Conv2d(5, 50, (19, 1), stride = (1, 1), padding=(9,0))\n", - " self.conv2 = nn.Conv2d(50, 50, (11, 1), stride = (1, 1), padding = (5,0))\n", - " self.conv3 = nn.Conv2d(50, 50, (7, 1), stride = (1, 1), padding = (4,0))\n", - " self.bn1 = nn.BatchNorm2d(50)\n", - " self.bn2 = nn.BatchNorm2d(50)\n", - " self.bn3 = nn.BatchNorm2d(50)\n", - " self.maxpool1 = nn.MaxPool2d((3, 1))\n", - " self.maxpool2 = nn.MaxPool2d((4, 1))\n", - " self.maxpool3 = nn.MaxPool2d((4, 1))\n", - "\n", - " self.fc1 = nn.Linear(4200, 1000)\n", - " self.bn4 = nn.BatchNorm1d(1000)\n", - "\n", - " self.fc2 = nn.Linear(1000, 1000)\n", - " self.bn5 = nn.BatchNorm1d(1000)\n", - "\n", - " self.fc3 = nn.Linear(1000, self.num_cell_types)\n", - "\n", - " def forward(self, s):\n", - " s = s.permute(0, 2, 1).contiguous() # batch_size x 5 x 1000\n", - " s = s.view(-1, 5, 1000, 1) # batch_size x 5 x 1000 x 1 [5 channels]\n", - " s = self.maxpool1(F.relu(self.bn1(self.conv1(s)))) # batch_size x 300 x 333 x 1\n", - " s = self.maxpool2(F.relu(self.bn2(self.conv2(s)))) # batch_size x 200 x 83 x 1\n", - " s = self.maxpool3(F.relu(self.bn3(self.conv3(s)))) # batch_size x 200 x 21 x 1\n", - " s = s.view(-1, 4200)\n", - " conv_out = s\n", - "\n", - " s = F.dropout(F.relu(self.bn4(self.fc1(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", - " s = F.dropout(F.relu(self.bn5(self.fc2(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", - " \n", - " s = self.fc3(s)\n", - "\n", - " return s, conv_out" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [], - "source": [ - "import math\n", - "import numpy as np\n", - "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "\n", - "\n", - "def double_conv(in_channels, out_channels): \n", - " return nn.Sequential(\n", - " nn.Conv1d(in_channels, out_channels, 7, padding=2), \n", - " nn.BatchNorm1d(out_channels), \n", - " nn.ReLU(inplace=True),\n", - " nn.Conv1d(out_channels, out_channels, 7, padding=3), \n", - " nn.BatchNorm1d(out_channels), \n", - " nn.ReLU(inplace=True)\n", - " )\n", - "\n", - "\n", - "class UNet(nn.Module):\n", - "\n", - " def __init__(self, n_class):\n", - " super().__init__()\n", - " \n", - " self.dconv_down1 = double_conv(6, 15)\n", - " self.dconv_down2 = double_conv(15, 22)\n", - " self.dconv_down3 = double_conv(22, 33)\n", - " self.dconv_down4 = double_conv(33, 49)\n", - " self.dconv_down5 = double_conv(49, 73)\n", - " self.dconv_down6 = double_conv(73, 109)\n", - "\n", - " self.maxpool = nn.MaxPool1d(2)\n", - " self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) \n", - " \n", - " self.dconv_up5 = double_conv(73 + 109, 73)\n", - " self.dconv_up4 = double_conv(49 + 73, 49)\n", - " self.dconv_up3 = double_conv(33 + 49, 33)\n", - " self.dconv_up2 = double_conv(22 + 33, 22)\n", - " self.dconv_up1 = double_conv(15 + 22, 15)\n", - " \n", - " self.conv_last = nn.Conv2d(15, n_class, 1)\n", - " \n", - " \n", - " def forward(self, x):\n", - " conv1 = self.dconv_down1(x)\n", - " x = self.maxpool(conv1)\n", - "\n", - " conv2 = self.dconv_down2(x)\n", - " x = self.maxpool(conv2)\n", - " \n", - " conv3 = self.dconv_down3(x)\n", - " x = self.maxpool(conv3)\n", - " \n", - " conv4 = self.dconv_down4(x)\n", - " x = self.maxpool(conv4)\n", - " \n", - " conv5 = self.dconv_down5(x)\n", - " x = self.maxpool(conv5)\n", - " \n", - " x = self.dconv_down6(x)\n", - " \n", - " x = self.upsample(x) \n", - " x = torch.cat([x, conv5], dim=1)\n", - " \n", - " x = self.dconv_up5(x)\n", - " x = self.upsample(x) \n", - " x = torch.cat([x, conv4], dim=1)\n", - " \n", - " x = self.dconv_up4(x)\n", - " x = self.upsample(x) \n", - " x = torch.cat([x, conv3], dim=1)\n", - " \n", - " x = self.dconv_up3(x)\n", - " x = self.upsample(x) \n", - " x = torch.cat([x, conv2], dim=1) \n", - "\n", - " x = self.dconv_up2(x)\n", - " x = self.upsample(x) \n", - " x = torch.cat([x, conv1], dim=1) \n", - " \n", - " x = self.dconv_up1(x)\n", - " \n", - " out = self.conv_last(x)\n", - " \n", - " return out" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "UNet(\n", - " (dconv_down1): Sequential(\n", - " (0): Conv1d(6, 15, kernel_size=(7,), stride=(1,), padding=(2,))\n", - " (1): BatchNorm1d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU(inplace=True)\n", - " (3): Conv1d(15, 15, kernel_size=(7,), stride=(1,), padding=(3,))\n", - " (4): BatchNorm1d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): ReLU(inplace=True)\n", - " )\n", - " (dconv_down2): Sequential(\n", - " (0): Conv1d(15, 22, kernel_size=(7,), stride=(1,), padding=(2,))\n", - " (1): BatchNorm1d(22, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU(inplace=True)\n", - " (3): Conv1d(22, 22, kernel_size=(7,), stride=(1,), padding=(3,))\n", - " (4): BatchNorm1d(22, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): ReLU(inplace=True)\n", - " )\n", - " (dconv_down3): Sequential(\n", - " (0): Conv1d(22, 33, kernel_size=(7,), stride=(1,), padding=(2,))\n", - " (1): BatchNorm1d(33, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU(inplace=True)\n", - " (3): Conv1d(33, 33, kernel_size=(7,), stride=(1,), padding=(3,))\n", - " (4): BatchNorm1d(33, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): ReLU(inplace=True)\n", - " )\n", - " (dconv_down4): Sequential(\n", - " (0): Conv1d(33, 49, kernel_size=(7,), stride=(1,), padding=(2,))\n", - " (1): BatchNorm1d(49, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU(inplace=True)\n", - " (3): Conv1d(49, 49, kernel_size=(7,), stride=(1,), padding=(3,))\n", - " (4): BatchNorm1d(49, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): ReLU(inplace=True)\n", - " )\n", - " (dconv_down5): Sequential(\n", - " (0): Conv1d(49, 73, kernel_size=(7,), stride=(1,), padding=(2,))\n", - " (1): BatchNorm1d(73, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU(inplace=True)\n", - " (3): Conv1d(73, 73, kernel_size=(7,), stride=(1,), padding=(3,))\n", - " (4): BatchNorm1d(73, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): ReLU(inplace=True)\n", - " )\n", - " (dconv_down6): Sequential(\n", - " (0): Conv1d(73, 109, kernel_size=(7,), stride=(1,), padding=(2,))\n", - " (1): BatchNorm1d(109, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU(inplace=True)\n", - " (3): Conv1d(109, 109, kernel_size=(7,), stride=(1,), padding=(3,))\n", - " (4): BatchNorm1d(109, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): ReLU(inplace=True)\n", - " )\n", - " (maxpool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", - " (upsample): Upsample(scale_factor=2.0, mode=bilinear)\n", - " (dconv_up5): Sequential(\n", - " (0): Conv1d(182, 73, kernel_size=(7,), stride=(1,), padding=(2,))\n", - " (1): BatchNorm1d(73, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU(inplace=True)\n", - " (3): Conv1d(73, 73, kernel_size=(7,), stride=(1,), padding=(3,))\n", - " (4): BatchNorm1d(73, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): ReLU(inplace=True)\n", - " )\n", - " (dconv_up4): Sequential(\n", - " (0): Conv1d(122, 49, kernel_size=(7,), stride=(1,), padding=(2,))\n", - " (1): BatchNorm1d(49, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU(inplace=True)\n", - " (3): Conv1d(49, 49, kernel_size=(7,), stride=(1,), padding=(3,))\n", - " (4): BatchNorm1d(49, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): ReLU(inplace=True)\n", - " )\n", - " (dconv_up3): Sequential(\n", - " (0): Conv1d(82, 33, kernel_size=(7,), stride=(1,), padding=(2,))\n", - " (1): BatchNorm1d(33, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU(inplace=True)\n", - " (3): Conv1d(33, 33, kernel_size=(7,), stride=(1,), padding=(3,))\n", - " (4): BatchNorm1d(33, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): ReLU(inplace=True)\n", - " )\n", - " (dconv_up2): Sequential(\n", - " (0): Conv1d(55, 22, kernel_size=(7,), stride=(1,), padding=(2,))\n", - " (1): BatchNorm1d(22, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU(inplace=True)\n", - " (3): Conv1d(22, 22, kernel_size=(7,), stride=(1,), padding=(3,))\n", - " (4): BatchNorm1d(22, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): ReLU(inplace=True)\n", - " )\n", - " (dconv_up1): Sequential(\n", - " (0): Conv1d(37, 15, kernel_size=(7,), stride=(1,), padding=(2,))\n", - " (1): BatchNorm1d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (2): ReLU(inplace=True)\n", - " (3): Conv1d(15, 15, kernel_size=(7,), stride=(1,), padding=(3,))\n", - " (4): BatchNorm1d(15, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (5): ReLU(inplace=True)\n", - " )\n", - " (conv_last): Conv2d(15, 2, kernel_size=(1, 1), stride=(1, 1))\n", - ")" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "model = UNet(2)\n", - "model" - ] - }, - { - "cell_type": "code", - "execution_count": 101, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "485773" - ] - }, - "execution_count": 101, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def count_parameters(model):\n", - " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", - "\n", - "lst = [(x[0], x[1].numel()) for x in model.named_parameters()]\n", - "#np.sum([x[1] for x in lst])\n", - "count_parameters(model)" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "6955906" - ] - }, - "execution_count": 25, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "count_parameters(algorithm.model)" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [], - "source": [ - "lst = [(x[0], x[1].numel()) for x in algorithm.model.named_parameters()]" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true - } - }, - "outputs": [ - { - "data": { - "text/plain": [ - "DenseNet(\n", - " (features): Sequential(\n", - " (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n", - " (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu0): ReLU(inplace=True)\n", - " (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n", - " (denseblock1): _DenseBlock(\n", - " (denselayer1): _DenseLayer(\n", - " (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer2): _DenseLayer(\n", - " (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(96, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer3): _DenseLayer(\n", - " (norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer4): _DenseLayer(\n", - " (norm1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer5): _DenseLayer(\n", - " (norm1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer6): _DenseLayer(\n", - " (norm1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " )\n", - " (transition1): _Transition(\n", - " (norm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu): ReLU(inplace=True)\n", - " (conv): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", - " )\n", - " (denseblock2): _DenseBlock(\n", - " (denselayer1): _DenseLayer(\n", - " (norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer2): _DenseLayer(\n", - " (norm1): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(160, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer3): _DenseLayer(\n", - " (norm1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(192, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer4): _DenseLayer(\n", - " (norm1): BatchNorm2d(224, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(224, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer5): _DenseLayer(\n", - " (norm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer6): _DenseLayer(\n", - " (norm1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer7): _DenseLayer(\n", - " (norm1): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer8): _DenseLayer(\n", - " (norm1): BatchNorm2d(352, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer9): _DenseLayer(\n", - " (norm1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer10): _DenseLayer(\n", - " (norm1): BatchNorm2d(416, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer11): _DenseLayer(\n", - " (norm1): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(448, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer12): _DenseLayer(\n", - " (norm1): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(480, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " )\n", - " (transition2): _Transition(\n", - " (norm): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu): ReLU(inplace=True)\n", - " (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", - " )\n", - " (denseblock3): _DenseBlock(\n", - " (denselayer1): _DenseLayer(\n", - " (norm1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer2): _DenseLayer(\n", - " (norm1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(288, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer3): _DenseLayer(\n", - " (norm1): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(320, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer4): _DenseLayer(\n", - " (norm1): BatchNorm2d(352, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(352, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer5): _DenseLayer(\n", - " (norm1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(384, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer6): _DenseLayer(\n", - " (norm1): BatchNorm2d(416, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(416, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer7): _DenseLayer(\n", - " (norm1): BatchNorm2d(448, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(448, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer8): _DenseLayer(\n", - " (norm1): BatchNorm2d(480, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(480, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer9): _DenseLayer(\n", - " (norm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer10): _DenseLayer(\n", - " (norm1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer11): _DenseLayer(\n", - " (norm1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer12): _DenseLayer(\n", - " (norm1): BatchNorm2d(608, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer13): _DenseLayer(\n", - " (norm1): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer14): _DenseLayer(\n", - " (norm1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer15): _DenseLayer(\n", - " (norm1): BatchNorm2d(704, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(704, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer16): _DenseLayer(\n", - " (norm1): BatchNorm2d(736, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(736, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer17): _DenseLayer(\n", - " (norm1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer18): _DenseLayer(\n", - " (norm1): BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(800, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer19): _DenseLayer(\n", - " (norm1): BatchNorm2d(832, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer20): _DenseLayer(\n", - " (norm1): BatchNorm2d(864, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(864, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer21): _DenseLayer(\n", - " (norm1): BatchNorm2d(896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer22): _DenseLayer(\n", - " (norm1): BatchNorm2d(928, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(928, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer23): _DenseLayer(\n", - " (norm1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(960, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer24): _DenseLayer(\n", - " (norm1): BatchNorm2d(992, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(992, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " )\n", - " (transition3): _Transition(\n", - " (norm): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu): ReLU(inplace=True)\n", - " (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (pool): AvgPool2d(kernel_size=2, stride=2, padding=0)\n", - " )\n", - " (denseblock4): _DenseBlock(\n", - " (denselayer1): _DenseLayer(\n", - " (norm1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer2): _DenseLayer(\n", - " (norm1): BatchNorm2d(544, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(544, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer3): _DenseLayer(\n", - " (norm1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(576, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer4): _DenseLayer(\n", - " (norm1): BatchNorm2d(608, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(608, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer5): _DenseLayer(\n", - " (norm1): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer6): _DenseLayer(\n", - " (norm1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer7): _DenseLayer(\n", - " (norm1): BatchNorm2d(704, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(704, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer8): _DenseLayer(\n", - " (norm1): BatchNorm2d(736, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(736, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer9): _DenseLayer(\n", - " (norm1): BatchNorm2d(768, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(768, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer10): _DenseLayer(\n", - " (norm1): BatchNorm2d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(800, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer11): _DenseLayer(\n", - " (norm1): BatchNorm2d(832, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(832, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer12): _DenseLayer(\n", - " (norm1): BatchNorm2d(864, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(864, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer13): _DenseLayer(\n", - " (norm1): BatchNorm2d(896, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(896, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer14): _DenseLayer(\n", - " (norm1): BatchNorm2d(928, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(928, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer15): _DenseLayer(\n", - " (norm1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(960, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " (denselayer16): _DenseLayer(\n", - " (norm1): BatchNorm2d(992, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu1): ReLU(inplace=True)\n", - " (conv1): Conv2d(992, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " (relu2): ReLU(inplace=True)\n", - " (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " )\n", - " )\n", - " (norm5): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", - " )\n", - " (classifier): Linear(in_features=1024, out_features=2, bias=True)\n", - ")" - ] - }, - "execution_count": 39, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "algorithm.model" - ] - }, { "cell_type": "code", "execution_count": null, diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py index a491bedd..7a3d6975 100644 --- a/wilds/common/metrics/all_metrics.py +++ b/wilds/common/metrics/all_metrics.py @@ -58,6 +58,30 @@ def _compute_flattened(self, flattened_y_pred, flattened_y_true): def worst(self, metrics): return minimum(metrics) +class MultiTaskAveragePrecision(MultiTaskMetric): + def __init__(self, prediction_fn=logits_to_binary_pred, name=None, average='macro'): + self.prediction_fn = prediction_fn + if name is None: + name = f'avgprec' + if average is not None: + name+=f'-{average}' + self.average = average + super().__init__(name=name) + + def _compute_flattened(self, flattened_y_pred, flattened_y_true): + if self.prediction_fn is not None: + flattened_y_pred = self.prediction_fn(flattened_y_pred) + score = sklearn.metrics.average_precision_score( + np.array(flattened_y_true.squeeze().detach().cpu().numpy() > 0), + flattened_y_pred.squeeze().detach().cpu().numpy(), + average=self.average + ) + return torch.tensor(score).to(flattened_y_pred.device) + + def worst(self, metrics): + return minimum(metrics) + + class Recall(Metric): def __init__(self, prediction_fn=None, name=None, average='binary'): self.prediction_fn = prediction_fn @@ -90,7 +114,11 @@ def __init__(self, prediction_fn=logits_to_pred, name=None, average='macro'): def _compute(self, y_pred, y_true): if self.prediction_fn is not None: y_pred = self.prediction_fn(y_pred) - score = sklearn.metrics.average_precision_score(y_true, y_pred, average=self.average, labels=torch.unique(y_true)) + score = sklearn.metrics.average_precision_score( + np.array(y_true.squeeze().detach().cpu().numpy() > 0), + y_pred.squeeze().detach().cpu().numpy(), + average=self.average + ) return torch.tensor(score) def worst(self, metrics): diff --git a/wilds/common/metrics/metric.py b/wilds/common/metrics/metric.py index 9c4372b0..2bc8237e 100644 --- a/wilds/common/metrics/metric.py +++ b/wilds/common/metrics/metric.py @@ -232,6 +232,7 @@ def _compute(self, y_pred, y_true): def _compute_group_wise(self, y_pred, y_true, g, n_groups): flattened_metrics, indices = self.compute_flattened(y_pred, y_true, return_dict=False) flattened_g = g[indices] + print(flattened_metrics.shape, flattened_g.shape, y_pred.shape, y_true.shape) group_metrics, group_counts = avg_over_groups(flattened_metrics, flattened_g, n_groups) worst_group_metric = self.worst(group_metrics[group_counts>0]) return group_metrics, group_counts, worst_group_metric diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 23b70014..588b9fce 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -15,7 +15,7 @@ class EncodeTFBSDataset(WILDSDataset): This is a subset of the dataset from the ENCODE-DREAM in vivo Transcription Factor Binding Site Prediction Challenge. Input (x): - 1000-base-pair regions of sequence with a quantified chromatin accessibility readout. + 12800-base-pair regions of sequence with a quantified chromatin accessibility readout. Label (y): y is binary. It is 1 if the central 200bp region is bound by the transcription factor MAX, and 0 otherwise. @@ -36,9 +36,9 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): self._y_size = 128 # self._n_classes = 2 - self._train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] - self._val_chroms = ['chr2', 'chr9', 'chr11'] - self._test_chroms = ['chr1', 'chr8', 'chr21'] + self._train_chroms = ['chr3']#, 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] + self._val_chroms = ['chr2']#, 'chr9', 'chr11'] + self._test_chroms = ['chr1']#, 'chr8', 'chr21'] self._transcription_factor = 'MAX' self._train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] self._val_celltype = ['A549'] @@ -165,15 +165,18 @@ def get_input(self, idx, window_size=12800): (4) Window_size, the length of sequence returned (centered on the 6400bp region in (3)) """ this_metadata = self._metadata_df.iloc[idx, :] + chrom = this_metadata['chr'] interval_start = this_metadata['start'] - int(window_size/4) interval_end = interval_start + window_size #this_metadata['stop'] seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end] dnase_bw = self._dnase_allcelltypes[this_metadata['celltype']] dnase_this = dnase_bw.values(chrom, interval_start, interval_end, numpy=True) + # print("{}:{}-{}".format(chrom, interval_start, interval_end)) dnase_avg = self._dnase_allcelltypes['avg'].values(chrom, interval_start, interval_end, numpy=True) return torch.tensor(np.column_stack( - [np.nan_to_num(seq_this), np.nan_to_num(dnase_this), np.nan_to_num(dnase_avg)] - )) + [np.nan_to_num(seq_this), + np.nan_to_num(dnase_this), np.nan_to_num(dnase_avg)] + ).T) def eval(self, y_pred, y_true, metadata): return self.standard_group_eval( From 6b2f22c165652b480e9161ea0aa58e93d627258a Mon Sep 17 00:00:00 2001 From: aikanor Date: Thu, 18 Mar 2021 19:06:08 -0700 Subject: [PATCH 076/244] integration besides eval --- examples/sbox_run_expt.ipynb | 465 +++++++++------------------- wilds/common/metrics/all_metrics.py | 11 +- wilds/common/metrics/metric.py | 2 +- 3 files changed, 162 insertions(+), 316 deletions(-) diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb index 5040aeb0..86525331 100644 --- a/examples/sbox_run_expt.ipynb +++ b/examples/sbox_run_expt.ipynb @@ -11,14 +11,21 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "4860.4765625\n" + "396.69921875\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:The OGB package is out of date. Your version is 1.2.4, while the latest version is 1.3.0.\n" ] } ], @@ -69,7 +76,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "WARNING:root:The OGB package is out of date. Your version is 1.2.4, while the latest version is 1.2.6.\n" + "WARNING:root:The WILDS package is out of date. Your version is 1.0.0, while the latest version is 1.1.0.\n", + "WARNING:root:The OGB package is out of date. Your version is 1.2.4, while the latest version is 1.3.0.\n" ] } ], @@ -105,7 +113,7 @@ { "data": { "text/plain": [ - "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" + "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" ] }, "execution_count": 2, @@ -208,23 +216,38 @@ "outputs": [], "source": [ "argstr_camelyon = \"--dataset camelyon17 --algorithm ERM --root_dir data\"\n", - "# argstr_camelyon = \"--dataset civilcomments --algorithm ERM --root_dir data\"\n", "config_camelyon = parser.parse_args(argstr_camelyon.split())\n", "config_camelyon = populate_defaults(config_camelyon)\n", "\n", + "argstr_bdd100k = \"--dataset bdd100k --algorithm ERM --root_dir data\"\n", + "config_bdd100k = parser.parse_args(argstr_bdd100k.split())\n", + "config_bdd100k = populate_defaults(config_bdd100k)\n", + "\n", "argstr_encode = \"--dataset encode-tfbs --algorithm ERM --root_dir data\"\n", "config_encode = parser.parse_args(argstr_encode.split())\n", "config_encode = populate_defaults(config_encode)\n", "\n", "config = config_camelyon\n", - "config = config_encode\n" + "config = config_encode\n", + "config = config_bdd100k\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "Namespace(algo_log_metric=None, algorithm='ERM', batch_size=None, coral_penalty_weight=None, dataset='encode-tfbs', dataset_kwargs={}, device=0, distinct_groups=None, download=False, eval_epoch=None, eval_loader='standard', eval_only=False, eval_splits=[], eval_transform=None, evaluate_all_splits=True, frac=1.0, group_dro_step_size=None, groupby_fields=None, irm_lambda=None, irm_penalty_anneal_iters=None, loader_kwargs={'num_workers': 1, 'pin_memory': True}, log_dir='./logs', log_every=50, loss_function=None, lr=None, max_grad_norm=None, max_token_length=None, model=None, model_kwargs={'pretrained': False}, n_epochs=None, n_groups_per_batch=None, no_group_logging=None, optimizer=None, optimizer_kwargs={'momentum': 0.9}, progress_bar=False, resize_scale=None, resume=False, root_dir='data', save_best=True, save_last=True, save_step=None, scheduler=None, scheduler_kwargs={}, scheduler_metric_name=None, scheduler_metric_split='val', seed=0, split_scheme=None, target_resolution=None, train_loader=None, train_transform=None, uniform_over_groups=None, use_wandb=False, val_metric=None, val_metric_decreasing=None, weight_decay=None)" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "argstr_camelyon = \"--dataset camelyon17 --algorithm ERM --root_dir data\"\n", "# argstr_camelyon = \"--dataset civilcomments --algorithm ERM --root_dir data\"\n", @@ -237,23 +260,34 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "Namespace(algo_log_metric='multitask_accuracy', algorithm='ERM', batch_size=32, coral_penalty_weight=None, dataset='bdd100k', dataset_kwargs={}, device=0, distinct_groups=None, download=False, eval_epoch=None, eval_loader='standard', eval_only=False, eval_splits=[], eval_transform='image_base', evaluate_all_splits=True, frac=1.0, group_dro_step_size=None, groupby_fields=None, irm_lambda=None, irm_penalty_anneal_iters=None, loader_kwargs={'num_workers': 1, 'pin_memory': True}, log_dir='./logs', log_every=50, loss_function='multitask_bce', lr=0.001, max_grad_norm=None, max_token_length=None, model='resnet50', model_kwargs={'pretrained': False}, n_epochs=10, n_groups_per_batch=4, no_group_logging=True, optimizer='SGD', optimizer_kwargs={'momentum': 0.9}, progress_bar=False, resize_scale=None, resume=False, root_dir='data', save_best=True, save_last=True, save_step=None, scheduler=None, scheduler_kwargs={}, scheduler_metric_name=None, scheduler_metric_split='val', seed=0, split_scheme='official', target_resolution=(224, 224), train_loader='standard', train_transform='image_base', uniform_over_groups=False, use_wandb=False, val_metric='acc_all', val_metric_decreasing=False, weight_decay=0.0001)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "config.optimizer_kwargs = {}" + "config#.optimizer_kwargs = {}" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Dataset: encode-tfbs\n", + "Dataset: bdd100k\n", "Algorithm: ERM\n", "Root dir: data\n", "Split scheme: official\n", @@ -264,31 +298,31 @@ "Train loader: standard\n", "Uniform over groups: False\n", "Distinct groups: None\n", - "N groups per batch: 2\n", - "Batch size: 64\n", + "N groups per batch: 4\n", + "Batch size: 32\n", "Eval loader: standard\n", - "Model: leopard\n", + "Model: resnet50\n", "Model kwargs: {'pretrained': False}\n", - "Train transform: None\n", - "Eval transform: None\n", - "Target resolution: None\n", + "Train transform: image_base\n", + "Eval transform: image_base\n", + "Target resolution: (224, 224)\n", "Resize scale: None\n", "Max token length: None\n", "Loss function: multitask_bce\n", - "Groupby fields: ['celltype']\n", + "Groupby fields: None\n", "Group dro step size: None\n", "Coral penalty weight: None\n", "Irm lambda: None\n", "Irm penalty anneal iters: None\n", - "Algo log metric: multitask_avgprec\n", - "Val metric: acc_avg\n", + "Algo log metric: multitask_accuracy\n", + "Val metric: acc_all\n", "Val metric decreasing: False\n", - "N epochs: 5\n", - "Optimizer: Adam\n", + "N epochs: 10\n", + "Optimizer: SGD\n", "Lr: 0.001\n", - "Weight decay: 0.01\n", + "Weight decay: 0.0001\n", "Max grad norm: None\n", - "Optimizer kwargs: {}\n", + "Optimizer kwargs: {'momentum': 0.9}\n", "Scheduler: None\n", "Scheduler kwargs: {}\n", "Scheduler metric split: val\n", @@ -304,14 +338,11 @@ "Save step: None\n", "Save best: True\n", "Save last: True\n", - "No group logging: False\n", + "No group logging: True\n", "Use wandb: False\n", "Progress bar: False\n", "Resume: False\n", - "\n", - "chr3 2.9614717960357666\n", - "chr2 6.587897777557373\n", - "chr1 10.29332971572876\n" + "\n" ] } ], @@ -359,30 +390,6 @@ " dataset=full_dataset)" ] }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import copy\n", - "full_dataset_camelyon17 = copy.deepcopy(full_dataset)\n", - "\n", - "# supported.datasets[config_encode.dataset]\n", - "# print(config_camelyon.train_transform, config_encode.train_transform)\n" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -392,11 +399,12 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": { "collapsed": true, "jupyter": { - "outputs_hidden": true + "outputs_hidden": true, + "source_hidden": true } }, "outputs": [ @@ -404,36 +412,33 @@ "name": "stdout", "output_type": "stream", "text": [ - "chr3 3.0055365562438965\n", - "chr4 5.905960321426392\n", - "chr5 8.651455879211426\n", - "chr6 11.250766038894653\n", - "chr7 13.660939931869507\n", - "chr10 15.713522672653198\n", - "chr12 17.740623474121094\n", - "chr13 19.478207111358643\n", - "chr14 21.088634252548218\n", - "chr15 22.625713348388672\n", - "chr16 23.987269639968872\n", - "chr17 25.21428894996643\n", - "chr18 26.394341230392456\n", - "chr19 27.28497076034546\n", - "chr20 28.235496282577515\n", - "chr22 28.999913692474365\n", - "chrX 31.338406085968018\n", - "chr2 35.00527381896973\n", - "chr9 37.12277841567993\n", - "chr11 39.157737016677856\n", - "chr1 42.89226841926575\n", - "chr8 45.092690229415894\n", - "chr21 45.81230306625366\n", - "H1-hESC 45.81402635574341\n", - "HCT116 45.814292192459106\n", - "HeLa-S3 45.814526081085205\n", - "HepG2 45.814810276031494\n", - "K562 45.815062522888184\n", - "A549 45.81636619567871\n", - "GM12878 45.81674289703369\n" + "chr3 3.0039219856262207\n", + "chr4 5.89985990524292\n", + "chr5 8.640583038330078\n", + "chr6 11.237342596054077\n", + "chr7 13.666043519973755\n", + "chr10 15.858035326004028\n", + "chr12 17.94972252845764\n", + "chr13 19.689449071884155\n", + "chr14 21.30842876434326\n", + "chr15 22.856398582458496\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0m_seq_bp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mchrom\u001b[0m \u001b[0;32min\u001b[0m \u001b[0m_all_chroms\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 58\u001b[0;31m \u001b[0m_seq_bp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mchrom\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mseq_arr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mchrom\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 59\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mchrom\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mitime\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/numpy/lib/npyio.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 252\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmagic\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mformat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mMAGIC_PREFIX\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 253\u001b[0m \u001b[0mbytes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzip\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 254\u001b[0;31m return format.read_array(bytes,\n\u001b[0m\u001b[1;32m 255\u001b[0m \u001b[0mallow_pickle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mallow_pickle\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 256\u001b[0m pickle_kwargs=self.pickle_kwargs)\n", + "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/numpy/lib/format.py\u001b[0m in \u001b[0;36mread_array\u001b[0;34m(fp, allow_pickle, pickle_kwargs)\u001b[0m\n\u001b[1;32m 773\u001b[0m \u001b[0mread_count\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmax_read_count\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcount\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 774\u001b[0m \u001b[0mread_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mread_count\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitemsize\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 775\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_read_bytes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mread_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"array data\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 776\u001b[0m array[i:i+read_count] = numpy.frombuffer(data, dtype=dtype,\n\u001b[1;32m 777\u001b[0m count=read_count)\n", + "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/numpy/lib/format.py\u001b[0m in \u001b[0;36m_read_bytes\u001b[0;34m(fp, size, error_template)\u001b[0m\n\u001b[1;32m 902\u001b[0m \u001b[0;31m# done about that. note that regular files can't be non-blocking\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 903\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 904\u001b[0;31m \u001b[0mr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msize\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 905\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mr\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 906\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0msize\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/zipfile.py\u001b[0m in \u001b[0;36mread\u001b[0;34m(self, n)\u001b[0m\n\u001b[1;32m 938\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_offset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 939\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_eof\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 940\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_read1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 941\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 942\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_readbuffer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/zipfile.py\u001b[0m in \u001b[0;36m_read1\u001b[0;34m(self, n)\u001b[0m\n\u001b[1;32m 1028\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_left\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1029\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_eof\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1030\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_update_crc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1031\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1032\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/zipfile.py\u001b[0m in \u001b[0;36m_update_crc\u001b[0;34m(self, newdata)\u001b[0m\n\u001b[1;32m 953\u001b[0m \u001b[0;31m# No need to compute the CRC if we don't have a reference value\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 954\u001b[0m \u001b[0;32mreturn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 955\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_running_crc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcrc32\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnewdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_running_crc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 956\u001b[0m \u001b[0;31m# Check the CRC if we're at the end of the file\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 957\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_eof\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_running_crc\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_expected_crc\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], @@ -522,8 +527,12 @@ }, { "cell_type": "code", - "execution_count": 7, - "metadata": {}, + "execution_count": null, + "metadata": { + "jupyter": { + "source_hidden": true + } + }, "outputs": [], "source": [ "train_regions_mask = np.isin(_metadata_df['chr'], _train_chroms)\n", @@ -566,7 +575,11 @@ { "cell_type": "code", "execution_count": 325, - "metadata": {}, + "metadata": { + "jupyter": { + "source_hidden": true + } + }, "outputs": [], "source": [ "def get_random_label_vec(\n", @@ -609,7 +622,11 @@ { "cell_type": "code", "execution_count": 24, - "metadata": {}, + "metadata": { + "jupyter": { + "source_hidden": true + } + }, "outputs": [], "source": [ "import os, time\n", @@ -850,7 +867,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -858,38 +875,12 @@ "output_type": "stream", "text": [ "Train data...\n", - " celltype = H1-hESC: n = 5314\n", - " celltype = HCT116: n = 4759\n", - " celltype = HeLa-S3: n = 4635\n", - " celltype = HepG2: n = 4459\n", - " celltype = K562: n = 5169\n", - " celltype = A549: n = 0\n", - " celltype = GM12878: n = 0\n", - "Validation (ID) data...\n", - " celltype = H1-hESC: n = 6872\n", - " celltype = HCT116: n = 6315\n", - " celltype = HeLa-S3: n = 4219\n", - " celltype = HepG2: n = 8356\n", - " celltype = K562: n = 6538\n", - " celltype = A549: n = 0\n", - " celltype = GM12878: n = 0\n", + " n = 64993\n", + "Validation data...\n", + " n = 4860\n", "Test data...\n", - " celltype = H1-hESC: n = 0\n", - " celltype = HCT116: n = 0\n", - " celltype = HeLa-S3: n = 0\n", - " celltype = HepG2: n = 0\n", - " celltype = K562: n = 0\n", - " celltype = A549: n = 0\n", - " celltype = GM12878: n = 4487\n", - "Validation (OOD) data...\n", - " celltype = H1-hESC: n = 0\n", - " celltype = HCT116: n = 0\n", - " celltype = HeLa-S3: n = 0\n", - " celltype = HepG2: n = 0\n", - " celltype = K562: n = 0\n", - " celltype = A549: n = 6728\n", - " celltype = GM12878: n = 0\n", - "Dout: 128\n" + " n = 4742\n", + "Dout: 9\n" ] } ], @@ -969,7 +960,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -979,16 +970,6 @@ "# x = torch.transpose(x, 1, 2)" ] }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "d = algorithm.process_batch(batch)\n", - "# algorithm.loss.compute" - ] - }, { "cell_type": "code", "execution_count": 9, @@ -997,7 +978,7 @@ { "data": { "text/plain": [ - "tensor(0.7212, device='cuda:0', grad_fn=)" + "tensor(0.8208, device='cuda:0', grad_fn=)" ] }, "execution_count": 9, @@ -1006,6 +987,8 @@ } ], "source": [ + "d = algorithm.process_batch(batch)\n", + "\n", "a = algorithm.loss.compute(d['y_pred'], d['y_true'], return_dict=False)\n", "a" ] @@ -1017,141 +1000,8 @@ "outputs": [ { "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
chrstartstopcelltypesplit
39413chr21000320010009600A5493
39414chr2100032000100038400A5493
39415chr2100102400100108800A5493
39416chr2100172800100179200A5493
39417chr2100230400100236800A5493
..................
495287chr3999680010003200K5620
495288chr39997440099980800K5620
495289chr39998080099987200K5620
495290chr39998720099993600K5620
495291chr399993600100000000K5620
\n", - "

67851 rows × 5 columns

\n", - "
" - ], "text/plain": [ - " chr start stop celltype split\n", - "39413 chr2 10003200 10009600 A549 3\n", - "39414 chr2 100032000 100038400 A549 3\n", - "39415 chr2 100102400 100108800 A549 3\n", - "39416 chr2 100172800 100179200 A549 3\n", - "39417 chr2 100230400 100236800 A549 3\n", - "... ... ... ... ... ...\n", - "495287 chr3 9996800 10003200 K562 0\n", - "495288 chr3 99974400 99980800 K562 0\n", - "495289 chr3 99980800 99987200 K562 0\n", - "495290 chr3 99987200 99993600 K562 0\n", - "495291 chr3 99993600 100000000 K562 0\n", - "\n", - "[67851 rows x 5 columns]" + "" ] }, "execution_count": 10, @@ -1161,7 +1011,7 @@ ], "source": [ "#np.unique(full_dataset._metadata_df['split'], return_counts=True)\n", - "full_dataset._metadata_df" + "full_dataset" ] }, { @@ -1170,38 +1020,20 @@ "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "(array([0. , 0.5, 1. ], dtype=float32), array([7422683, 1007200, 255045]))" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.unique(full_dataset.y_array, return_counts=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "0.8546625832706961" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" + "ename": "NameError", + "evalue": "name 'importlib' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m#import importlib\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mimportlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mNameError\u001b[0m: name 'importlib' is not defined" + ] } ], "source": [ - "7422683/8684928" + "#import importlib\n", + "importlib.reload(train)" ] }, { @@ -1223,27 +1055,25 @@ "\n", "Epoch [0]:\n", "\n", - "Train:\n", - "torch.Size([8192]) torch.Size([8192]) torch.Size([64, 128]) torch.Size([64, 128])\n", - "torch.Size([]) torch.Size([8192]) torch.Size([64, 128]) torch.Size([64, 128])\n" + "Train:\n" ] }, { - "ename": "AssertionError", + "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m train(\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdatasets\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0mbest_val_metric\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m train(\n\u001b[0m\u001b[1;32m 26\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdatasets\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/wilds/examples/train.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(algorithm, datasets, general_logger, config, epoch_offset, best_val_metric)\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0;31m# First run training\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 90\u001b[0;31m \u001b[0mrun_epoch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgeneral_logger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 91\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;31m# Then run val\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/wilds/examples/train.py\u001b[0m in \u001b[0;36mrun_epoch\u001b[0;34m(algorithm, dataset, general_logger, epoch, config, train)\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miterator\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 43\u001b[0;31m \u001b[0mbatch_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 44\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0mbatch_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/wilds/examples/algorithms/single_model_algorithm.py\u001b[0m in \u001b[0;36mupdate\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 103\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_update\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[0;31m# log results\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 105\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_log\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 106\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msanitize_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/examples/algorithms/group_algorithm.py\u001b[0m in \u001b[0;36mupdate_log\u001b[0;34m(self, results)\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mm\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogged_metrics\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_group_logging\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m group_metrics, group_counts, worst_group_metric = m.compute_group_wise(\n\u001b[0m\u001b[1;32m 50\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'y_pred'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'y_true'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/wilds/common/metrics/metric.py\u001b[0m in \u001b[0;36mcompute_group_wise\u001b[0;34m(self, y_pred, y_true, g, n_groups, return_dict)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mDictionary\u001b[0m \u001b[0mof\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \"\"\"\n\u001b[0;32m--> 114\u001b[0;31m \u001b[0mgroup_metrics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgroup_counts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mworst_group_metric\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compute_group_wise\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_groups\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 115\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/wilds/common/metrics/metric.py\u001b[0m in \u001b[0;36m_compute_group_wise\u001b[0;34m(self, y_pred, y_true, g, n_groups)\u001b[0m\n\u001b[1;32m 234\u001b[0m \u001b[0mflattened_g\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mindices\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 235\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mflattened_metrics\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mflattened_g\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 236\u001b[0;31m \u001b[0mgroup_metrics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgroup_counts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mavg_over_groups\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mflattened_metrics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mflattened_g\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_groups\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 237\u001b[0m \u001b[0mworst_group_metric\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mworst\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgroup_metrics\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mgroup_counts\u001b[0m\u001b[0;34m>\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 238\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mgroup_metrics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgroup_counts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mworst_group_metric\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/wilds/common/utils.py\u001b[0m in \u001b[0;36mavg_over_groups\u001b[0;34m(v, g, n_groups)\u001b[0m\n\u001b[1;32m 84\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 85\u001b[0m \u001b[0mdevice\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 86\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 87\u001b[0m \u001b[0mgroup_count\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_counts\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_groups\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[0mgroup_avgs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch_scatter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscatter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msrc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindex\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mn_groups\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'mean'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mAssertionError\u001b[0m: " + "\u001b[0;32m~/wilds/examples/algorithms/group_algorithm.py\u001b[0m in \u001b[0;36mupdate_log\u001b[0;34m(self, results)\u001b[0m\n\u001b[1;32m 54\u001b[0m return_dict=False)\n\u001b[1;32m 55\u001b[0m \u001b[0mbatch_log\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34mf'{self.group_prefix}{m.name}'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgroup_metrics\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 56\u001b[0;31m batch_log[m.agg_metric_field] = m.compute(\n\u001b[0m\u001b[1;32m 57\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'y_pred'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'y_true'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/wilds/common/metrics/metric.py\u001b[0m in \u001b[0;36mcompute\u001b[0;34m(self, y_pred, y_true, return_dict)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0magg_metric\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0.\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0my_true\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 87\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 88\u001b[0;31m \u001b[0magg_metric\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 89\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 90\u001b[0m results = {\n", + "\u001b[0;32m~/wilds/wilds/common/metrics/metric.py\u001b[0m in \u001b[0;36m_compute\u001b[0;34m(self, y_pred, y_true)\u001b[0m\n\u001b[1;32m 224\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 225\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_compute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 226\u001b[0;31m \u001b[0mflattened_metrics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_flattened\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 227\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mflattened_metrics\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 228\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0.\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0my_true\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/wilds/common/metrics/metric.py\u001b[0m in \u001b[0;36mcompute_flattened\u001b[0;34m(self, y_pred, y_true, return_dict)\u001b[0m\n\u001b[1;32m 240\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_flattened\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 241\u001b[0m \u001b[0mis_labeled\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m~\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misnan\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_true\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 242\u001b[0;31m \u001b[0mbatch_idx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwhere\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mis_labeled\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 243\u001b[0m \u001b[0mflattened_y_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mis_labeled\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 244\u001b[0m \u001b[0mflattened_y_true\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mis_labeled\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], @@ -1271,8 +1101,7 @@ " if resume_success == False:\n", " epoch_offset=0\n", " best_val_metric=None\n", - "\n", - "\n", + " \n", " train(\n", " algorithm=algorithm,\n", " datasets=datasets,\n", @@ -1324,6 +1153,20 @@ "outputs": [], "source": [] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": null, diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py index 7a3d6975..01c81111 100644 --- a/wilds/common/metrics/all_metrics.py +++ b/wilds/common/metrics/all_metrics.py @@ -71,17 +71,20 @@ def __init__(self, prediction_fn=logits_to_binary_pred, name=None, average='macr def _compute_flattened(self, flattened_y_pred, flattened_y_true): if self.prediction_fn is not None: flattened_y_pred = self.prediction_fn(flattened_y_pred) + ytr = np.array(flattened_y_true.squeeze().detach().cpu().numpy() > 0) + ypr = flattened_y_pred.squeeze().detach().cpu().numpy() score = sklearn.metrics.average_precision_score( - np.array(flattened_y_true.squeeze().detach().cpu().numpy() > 0), - flattened_y_pred.squeeze().detach().cpu().numpy(), + ytr, + ypr, average=self.average ) - return torch.tensor(score).to(flattened_y_pred.device) + to_ret = torch.tensor(score).to(flattened_y_pred.device) + print("why ", ytr, ytr.shape, ypr, ypr.shape, score, to_ret) + return to_ret def worst(self, metrics): return minimum(metrics) - class Recall(Metric): def __init__(self, prediction_fn=None, name=None, average='binary'): self.prediction_fn = prediction_fn diff --git a/wilds/common/metrics/metric.py b/wilds/common/metrics/metric.py index 2bc8237e..2e2b0c8f 100644 --- a/wilds/common/metrics/metric.py +++ b/wilds/common/metrics/metric.py @@ -232,7 +232,7 @@ def _compute(self, y_pred, y_true): def _compute_group_wise(self, y_pred, y_true, g, n_groups): flattened_metrics, indices = self.compute_flattened(y_pred, y_true, return_dict=False) flattened_g = g[indices] - print(flattened_metrics.shape, flattened_g.shape, y_pred.shape, y_true.shape) + print(flattened_metrics.shape, flattened_g.shape, (indices > 0).sum(), y_pred.shape, y_true.shape) group_metrics, group_counts = avg_over_groups(flattened_metrics, flattened_g, n_groups) worst_group_metric = self.worst(group_metrics[group_counts>0]) return group_metrics, group_counts, worst_group_metric From 3b5f8f5e8ecc0d55e267ccab2b2de1b58a42824a Mon Sep 17 00:00:00 2001 From: aikanor Date: Fri, 19 Mar 2021 20:17:48 -0700 Subject: [PATCH 077/244] integration besides eval 2/ --- examples/sbox_run_expt.ipynb | 494 +++++++++++----------------- wilds/common/metrics/all_metrics.py | 30 ++ 2 files changed, 221 insertions(+), 303 deletions(-) diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb index 86525331..2aad102b 100644 --- a/examples/sbox_run_expt.ipynb +++ b/examples/sbox_run_expt.ipynb @@ -11,21 +11,14 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "396.69921875\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:root:The OGB package is out of date. Your version is 1.2.4, while the latest version is 1.3.0.\n" + "47.42578125\n" ] } ], @@ -113,7 +106,7 @@ { "data": { "text/plain": [ - "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" + "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" ] }, "execution_count": 2, @@ -229,7 +222,7 @@ "\n", "config = config_camelyon\n", "config = config_encode\n", - "config = config_bdd100k\n" + "# config = config_bdd100k\n" ] }, { @@ -262,20 +255,9 @@ "cell_type": "code", "execution_count": 5, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Namespace(algo_log_metric='multitask_accuracy', algorithm='ERM', batch_size=32, coral_penalty_weight=None, dataset='bdd100k', dataset_kwargs={}, device=0, distinct_groups=None, download=False, eval_epoch=None, eval_loader='standard', eval_only=False, eval_splits=[], eval_transform='image_base', evaluate_all_splits=True, frac=1.0, group_dro_step_size=None, groupby_fields=None, irm_lambda=None, irm_penalty_anneal_iters=None, loader_kwargs={'num_workers': 1, 'pin_memory': True}, log_dir='./logs', log_every=50, loss_function='multitask_bce', lr=0.001, max_grad_norm=None, max_token_length=None, model='resnet50', model_kwargs={'pretrained': False}, n_epochs=10, n_groups_per_batch=4, no_group_logging=True, optimizer='SGD', optimizer_kwargs={'momentum': 0.9}, progress_bar=False, resize_scale=None, resume=False, root_dir='data', save_best=True, save_last=True, save_step=None, scheduler=None, scheduler_kwargs={}, scheduler_metric_name=None, scheduler_metric_split='val', seed=0, split_scheme='official', target_resolution=(224, 224), train_loader='standard', train_transform='image_base', uniform_over_groups=False, use_wandb=False, val_metric='acc_all', val_metric_decreasing=False, weight_decay=0.0001)" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "config#.optimizer_kwargs = {}" + "config.optimizer_kwargs = {}" ] }, { @@ -287,7 +269,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Dataset: bdd100k\n", + "Dataset: encode-tfbs\n", "Algorithm: ERM\n", "Root dir: data\n", "Split scheme: official\n", @@ -298,31 +280,31 @@ "Train loader: standard\n", "Uniform over groups: False\n", "Distinct groups: None\n", - "N groups per batch: 4\n", - "Batch size: 32\n", + "N groups per batch: 2\n", + "Batch size: 64\n", "Eval loader: standard\n", - "Model: resnet50\n", + "Model: leopard\n", "Model kwargs: {'pretrained': False}\n", - "Train transform: image_base\n", - "Eval transform: image_base\n", - "Target resolution: (224, 224)\n", + "Train transform: None\n", + "Eval transform: None\n", + "Target resolution: None\n", "Resize scale: None\n", "Max token length: None\n", "Loss function: multitask_bce\n", - "Groupby fields: None\n", + "Groupby fields: ['celltype']\n", "Group dro step size: None\n", "Coral penalty weight: None\n", "Irm lambda: None\n", "Irm penalty anneal iters: None\n", - "Algo log metric: multitask_accuracy\n", - "Val metric: acc_all\n", + "Algo log metric: multitask_avgprec\n", + "Val metric: acc_avg\n", "Val metric decreasing: False\n", - "N epochs: 10\n", - "Optimizer: SGD\n", + "N epochs: 5\n", + "Optimizer: Adam\n", "Lr: 0.001\n", - "Weight decay: 0.0001\n", + "Weight decay: 0.01\n", "Max grad norm: None\n", - "Optimizer kwargs: {'momentum': 0.9}\n", + "Optimizer kwargs: {}\n", "Scheduler: None\n", "Scheduler kwargs: {}\n", "Scheduler metric split: val\n", @@ -338,11 +320,14 @@ "Save step: None\n", "Save best: True\n", "Save last: True\n", - "No group logging: True\n", + "No group logging: False\n", "Use wandb: False\n", "Progress bar: False\n", "Resume: False\n", - "\n" + "\n", + "chr3 2.979121685028076\n", + "chr2 6.626891374588013\n", + "chr1 10.355815410614014\n" ] } ], @@ -612,252 +597,6 @@ " return mdf, y_label_vec" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Dataset object (long version)" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": { - "jupyter": { - "source_hidden": true - } - }, - "outputs": [], - "source": [ - "import os, time\n", - "import torch\n", - "import pandas as pd\n", - "import numpy as np\n", - "from wilds.datasets.wilds_dataset import WILDSDataset\n", - "from wilds.common.grouper import CombinatorialGrouper\n", - "from wilds.common.metrics.all_metrics import Accuracy\n", - "\n", - "class EncodeTFBSDataset(WILDSDataset):\n", - " \"\"\"\n", - " ENCODE-DREAM-wilds dataset of transcription factor binding sites. \n", - " This is a subset of the dataset from the ENCODE-DREAM in vivo Transcription Factor Binding Site Prediction Challenge. \n", - " \n", - " Input (x):\n", - " 1000-base-pair regions of sequence with a quantified chromatin accessibility readout.\n", - "\n", - " Label (y):\n", - " y is binary. It is 1 if the central 200bp region is bound by the transcription factor MAX, and 0 otherwise.\n", - "\n", - " Metadata:\n", - " Each sequence is annotated with the celltype of origin (a string) and the chromosome of origin (a string).\n", - " \n", - " Website:\n", - " https://www.synapse.org/#!Synapse:syn6131484\n", - " \"\"\"\n", - "\n", - " def __init__(self, root_dir='data', download=False, split_scheme='official'):\n", - " itime = time.time()\n", - " self._dataset_name = 'encode-tfbs'\n", - " self._version = '1.0'\n", - " self._download_url = 'https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/'\n", - " self._data_dir = self.initialize_data_dir(root_dir, download)\n", - " self._y_size = 128\n", - " # self._n_classes = 2\n", - " \n", - " self._train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX']\n", - " self._val_chroms = ['chr2', 'chr9', 'chr11']\n", - " self._test_chroms = ['chr1', 'chr8', 'chr21']\n", - " self._transcription_factor = 'MAX'\n", - " self._train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", - " self._val_celltype = ['A549']\n", - " self._test_celltype = ['GM12878']\n", - " self._all_chroms = self._train_chroms + self._val_chroms + self._test_chroms\n", - " self._all_celltypes = self._train_celltypes + self._val_celltype + self._test_celltype\n", - " \n", - " self._metadata_map = {}\n", - " self._metadata_map['chr'] = self._all_chroms\n", - " self._metadata_map['celltype'] = self._all_celltypes\n", - " \n", - " # Get the splits\n", - " if split_scheme=='official':\n", - " split_scheme = 'standard'\n", - " \n", - " self._split_scheme = split_scheme\n", - " self._split_dict = {\n", - " 'train': 0,\n", - " 'id_val': 1,\n", - " 'test': 2,\n", - " 'val': 3\n", - " }\n", - " self._split_names = {\n", - " 'train': 'Train',\n", - " 'id_val': 'Validation (ID)',\n", - " 'test': 'Test',\n", - " 'val': 'Validation (OOD)',\n", - " }\n", - " \n", - " # Load sequence and DNase features\n", - " sequence_filename = os.path.join(self._data_dir, 'sequence.npz')\n", - " seq_arr = np.load(sequence_filename)\n", - " self._seq_bp = {}\n", - " for chrom in self._all_chroms: #seq_arr:\n", - " self._seq_bp[chrom] = seq_arr[chrom]\n", - " print(chrom, time.time() - itime)\n", - " \n", - " self._dnase_allcelltypes = {}\n", - " ct = 'avg'\n", - " dnase_avg_bw_path = os.path.join(self._data_dir, 'Leopard_dnase/{}.bigwig'.format(ct))\n", - " self._dnase_allcelltypes[ct] = pyBigWig.open(dnase_avg_bw_path)\n", - " for ct in self._all_celltypes:\n", - " \"\"\"\n", - " dnase_filename = os.path.join(self._data_dir, '{}_dnase.npz'.format(ct))\n", - " dnase_npz_contents = np.load(dnase_filename)\n", - " self._dnase_allcelltypes[ct] = {}\n", - " for chrom in self._all_chroms: #self._seq_bp:\n", - " self._dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom]\n", - " \"\"\"\n", - " dnase_bw_path = os.path.join(self._data_dir, 'Leopard_dnase/{}.bigwig'.format(ct))\n", - " self._dnase_allcelltypes[ct] = pyBigWig.open(dnase_bw_path)\n", - " \n", - " self._metadata_df = pd.read_csv(\n", - " self._data_dir + '/labels/MAX/metadata_df.bed', sep='\\t', header=None, \n", - " index_col=None, names=['chr', 'start', 'stop', 'celltype']\n", - " )\n", - " \n", - " train_regions_mask = np.isin(self._metadata_df['chr'], self._train_chroms)\n", - " val_regions_mask = np.isin(self._metadata_df['chr'], self._val_chroms)\n", - " test_regions_mask = np.isin(self._metadata_df['chr'], self._test_chroms)\n", - " train_celltype_mask = np.isin(self._metadata_df['celltype'], self._train_celltypes)\n", - " val_celltype_mask = np.isin(self._metadata_df['celltype'], self._val_celltype)\n", - " test_celltype_mask = np.isin(self._metadata_df['celltype'], self._test_celltype)\n", - " \n", - " split_array = -1*np.ones(self._metadata_df.shape[0]).astype(int)\n", - " split_array[np.logical_and(train_regions_mask, train_celltype_mask)] = self._split_dict['train']\n", - " split_array[np.logical_and(test_regions_mask, test_celltype_mask)] = self._split_dict['test']\n", - " # Validate using validation chr, either using a designated validation cell line ('val') or a training cell line ('id_val')\n", - " split_array[np.logical_and(val_regions_mask, val_celltype_mask)] = self._split_dict['val']\n", - " split_array[np.logical_and(val_regions_mask, train_celltype_mask)] = self._split_dict['id_val']\n", - " \n", - " if self._split_scheme=='standard':\n", - " self._metadata_df.insert(len(self._metadata_df.columns), 'split', split_array)\n", - " else:\n", - " raise ValueError(f'Split scheme {self._split_scheme} not recognized')\n", - " \n", - " metadata_mask = (self._metadata_df['split'] != -1)\n", - " self._metadata_df = self._metadata_df[self._metadata_df['split'] != -1]\n", - " \n", - " chr_ints = self._metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['chr'])] )).values\n", - " celltype_ints = self._metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['celltype'])] )).values\n", - " self._split_array = self._metadata_df['split'].values\n", - " self._y_array = torch.Tensor(np.load(self._data_dir + '/labels/MAX/metadata_y.npy'))\n", - " self._y_array = self._y_array[metadata_mask]\n", - " \n", - " self._metadata_array = torch.stack(\n", - " (torch.LongTensor(chr_ints), \n", - " torch.LongTensor(celltype_ints)\n", - " ),\n", - " dim=1)\n", - " self._metadata_fields = ['chr', 'celltype']\n", - " \n", - " self._eval_grouper = CombinatorialGrouper(\n", - " dataset=self,\n", - " groupby_fields=['celltype'])\n", - " \n", - " self._metric = Accuracy()\n", - " \n", - " super().__init__(root_dir, download, split_scheme)\n", - " \n", - " \"\"\"\n", - " def get_random_label_vec(metadata_df, output_size=128):\n", - " # Sample a positively labeled region at random\n", - " pos_mdf = metadata_df[metadata_df['y'] == 1] #.iloc[ metadata_df['chr'] == s['chr'], : ]\n", - " pos_seed_region = pos_mdf.iloc[np.random.randint(pos_mdf.shape[0])]\n", - "\n", - " # Extract regions from this chromosome in this celltype, to get a window of labels from\n", - " chr_msk = np.array(metadata_df['chr']) == pos_seed_region['chr']\n", - " ct_msk = np.array(metadata_df['celltype']) == pos_seed_region['celltype']\n", - " mdf = metadata_df[chr_msk & ct_msk]\n", - "\n", - " # Get labels\n", - " start_ndx = np.where(mdf['start'] == pos_seed_region['start'])[0][0]\n", - " y_label_vec = mdf.iloc[start_ndx:start_ndx+output_size, :]['y']\n", - " \"\"\"\n", - " \n", - " def get_input(self, idx, window_size=12800):\n", - " \"\"\"\n", - " Returns x for a given idx in metadata_array, which has been filtered to only take windows with the desired stride.\n", - " Computes this from: \n", - " (1) sequence features in self._seq_bp\n", - " (2) DNase bigwig file handles in self._dnase_allcelltypes\n", - " (3) Metadata for the index (location along the genome with 6400bp window width)\n", - " (4) Window_size, the length of sequence returned (centered on the 6400bp region in (3))\n", - " \"\"\"\n", - " this_metadata = self._metadata_df.iloc[idx, :]\n", - " interval_start = this_metadata['start'] - int(window_size/4)\n", - " interval_end = interval_start + window_size #this_metadata['stop']\n", - " seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end]\n", - " dnase_bw = self._dnase_allcelltypes[this_metadata['celltype']]\n", - " dnase_this = dnase_bw.values(chrom, interval_start, interval_end, numpy=True)\n", - " dnase_avg = self._dnase_allcelltypes['avg'].values(chrom, interval_start, interval_end, numpy=True)\n", - " return torch.tensor(np.column_stack(\n", - " [np.nan_to_num(seq_this), np.nan_to_num(dnase_this), np.nan_to_num(dnase_avg)]\n", - " ))\n", - "\n", - " def eval(self, y_pred, y_true, metadata):\n", - " return self.standard_group_eval(\n", - " self._metric,\n", - " self._eval_grouper,\n", - " y_pred, y_true, metadata)" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "chr3 3.0425407886505127\n", - "chr4 5.967821359634399\n", - "chr5 8.747126340866089\n", - "chr6 11.370141744613647\n", - "chr7 13.802208423614502\n", - "chr10 15.875979900360107\n", - "chr12 17.929850339889526\n", - "chr13 19.67976665496826\n", - "chr14 21.306750059127808\n", - "chr15 22.866544723510742\n", - "chr16 24.241100788116455\n", - "chr17 25.480982303619385\n", - "chr18 26.677065134048462\n", - "chr19 27.579110622406006\n", - "chr20 28.545915603637695\n", - "chr22 29.323810577392578\n", - "chrX 31.698036670684814\n", - "chr2 35.40705943107605\n", - "chr9 37.5518524646759\n", - "chr11 39.61783218383789\n", - "chr1 43.411964893341064\n", - "chr8 45.64823389053345\n", - "chr21 46.377281188964844\n" - ] - } - ], - "source": [ - "full_dataset_encode = EncodeTFBSDataset(\n", - " root_dir=config.root_dir,\n", - " download=config.download,\n", - " split_scheme=config.split_scheme,\n", - " **config.dataset_kwargs)" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -875,12 +614,38 @@ "output_type": "stream", "text": [ "Train data...\n", - " n = 64993\n", - "Validation data...\n", - " n = 4860\n", + " celltype = H1-hESC: n = 5314\n", + " celltype = HCT116: n = 4759\n", + " celltype = HeLa-S3: n = 4635\n", + " celltype = HepG2: n = 4459\n", + " celltype = K562: n = 5169\n", + " celltype = A549: n = 0\n", + " celltype = GM12878: n = 0\n", + "Validation (ID) data...\n", + " celltype = H1-hESC: n = 6872\n", + " celltype = HCT116: n = 6315\n", + " celltype = HeLa-S3: n = 4219\n", + " celltype = HepG2: n = 8356\n", + " celltype = K562: n = 6538\n", + " celltype = A549: n = 0\n", + " celltype = GM12878: n = 0\n", "Test data...\n", - " n = 4742\n", - "Dout: 9\n" + " celltype = H1-hESC: n = 0\n", + " celltype = HCT116: n = 0\n", + " celltype = HeLa-S3: n = 0\n", + " celltype = HepG2: n = 0\n", + " celltype = K562: n = 0\n", + " celltype = A549: n = 0\n", + " celltype = GM12878: n = 4487\n", + "Validation (OOD) data...\n", + " celltype = H1-hESC: n = 0\n", + " celltype = HCT116: n = 0\n", + " celltype = HeLa-S3: n = 0\n", + " celltype = HepG2: n = 0\n", + " celltype = K562: n = 0\n", + " celltype = A549: n = 6728\n", + " celltype = GM12878: n = 0\n", + "Dout: 128\n" ] } ], @@ -978,7 +743,7 @@ { "data": { "text/plain": [ - "tensor(0.8208, device='cuda:0', grad_fn=)" + "tensor(0.7212, device='cuda:0', grad_fn=)" ] }, "execution_count": 9, @@ -1001,7 +766,7 @@ { "data": { "text/plain": [ - "" + "torch.Size([64, 128])" ] }, "execution_count": 10, @@ -1011,7 +776,7 @@ ], "source": [ "#np.unique(full_dataset._metadata_df['split'], return_counts=True)\n", - "full_dataset" + "y_true.squeeze().shape" ] }, { @@ -1036,6 +801,26 @@ "importlib.reload(train)" ] }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Namespace(algo_log_metric='multitask_avgprec', algorithm='ERM', batch_size=64, coral_penalty_weight=None, dataset='encode-tfbs', dataset_kwargs={}, device=device(type='cuda', index=0), distinct_groups=None, download=False, eval_epoch=None, eval_loader='standard', eval_only=False, eval_splits=[], eval_transform=None, evaluate_all_splits=True, frac=1.0, group_dro_step_size=None, groupby_fields=['celltype'], irm_lambda=None, irm_penalty_anneal_iters=None, loader_kwargs={'num_workers': 1, 'pin_memory': True}, log_dir='./logs', log_every=50, loss_function='multitask_bce', lr=0.001, max_grad_norm=None, max_token_length=None, model='leopard', model_kwargs={'pretrained': False}, n_epochs=5, n_groups_per_batch=2, no_group_logging=False, optimizer='Adam', optimizer_kwargs={}, progress_bar=False, resize_scale=None, resume=False, root_dir='data', save_best=True, save_last=True, save_step=None, scheduler=None, scheduler_kwargs={}, scheduler_metric_name=None, scheduler_metric_split='val', seed=0, split_scheme='official', target_resolution=None, train_loader='standard', train_transform=None, uniform_over_groups=False, use_wandb=False, val_metric='acc_avg', val_metric_decreasing=False, weight_decay=0.01)" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "config" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -1055,25 +840,128 @@ "\n", "Epoch [0]:\n", "\n", - "Train:\n" + "Train:\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (2432,) [1 0 1 ... 1 1 0] (2432,) 0.09923777357272781 tensor(0.0992, dtype=torch.float64)\n", + "why [False False False ... False False False] (1792,) [1 1 0 ... 1 0 1] (1792,) 0.18020602071676678 tensor(0.1802, dtype=torch.float64)\n", + "why [False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False True\n", + " True True True True True True True True True True True False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False True\n", + " True True True True True True True True True True True False\n", + " False True True True True True True True True True True True\n", + " True True True True True False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False True True True True True True True True True\n", + " True True True False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False True True\n", + " True True True True True True True True True True False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False True True True True True True True True True\n", + " True True True False False False False False False False False False\n", + " False False False False True True True True True True True True\n", + " True True True False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False True True True True True True True\n", + " True True True True True False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False] (896,) [1 0 1 0 0 1 1 0 1 0 0 1 0 1 1 1 0 1 1 0 1 0 1 1 0 0 1 1 1 1 1 1 0 1 0 0 0\n", + " 1 0 1 1 0 0 0 1 1 1 0 1 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 1 0 0 1 0 0 1 0 0 0\n", + " 0 1 1 1 0 1 1 0 1 0 1 0 0 1 0 0 1 1 0 0 0 0 1 1 0 0 0 0 1 0 0 1 0 0 0 0 0\n", + " 1 1 0 0 1 1 0 1 0 1 0 0 1 0 1 0 1 1 0 0 0 1 1 1 1 1 0 1 0 0 0 1 0 0 1 1 0\n", + " 1 0 0 1 0 0 0 0 1 1 1 1 0 0 0 0 0 1 0 1 0 1 1 1 1 0 1 1 0 0 1 1 1 1 1 1 0\n", + " 0 0 1 1 1 1 1 1 1 1 0 0 1 0 1 0 0 0 1 1 1 0 1 1 1 1 1 1 0 1 0 1 1 0 1 1 0\n", + " 1 0 1 1 1 1 1 0 0 1 1 1 0 0 1 1 0 0 1 0 1 0 0 1 1 1 1 0 1 1 1 1 1 0 1 1 1\n", + " 1 0 0 0 1 0 0 1 0 1 1 1 0 1 1 1 0 0 1 1 0 0 0 1 0 0 1 1 0 1 0 0 0 1 0 0 0\n", + " 1 1 0 1 0 1 1 0 0 1 0 1 0 1 1 1 1 0 1 0 1 0 1 0 0 1 1 0 1 1 0 1 1 1 1 0 0\n", + " 1 1 0 1 0 0 1 1 1 0 0 0 1 0 0 0 0 0 1 0 0 1 1 0 0 0 0 0 1 1 0 1 1 1 1 0 1\n", + " 0 0 0 1 1 0 0 0 0 1 1 1 0 1 1 1 0 1 0 1 0 0 0 0 0 1 1 1 1 1 0 1 0 1 0 0 1\n", + " 1 0 0 1 1 0 0 0 1 1 1 1 0 1 1 0 1 1 1 0 0 1 0 1 0 0 1 1 0 0 0 0 0 0 1 0 0\n", + " 0 0 1 0 1 0 1 0 1 1 1 1 1 0 1 1 1 0 1 0 1 1 0 0 0 0 1 0 1 1 0 1 0 1 1 1 0\n", + " 0 1 0 1 1 1 0 0 1 0 0 1 0 1 1 0 1 0 1 0 0 1 0 1 1 0 1 0 1 0 0 1 1 1 0 1 0\n", + " 1 0 1 1 1 1 1 0 0 1 0 1 0 0 0 0 1 0 0 0 0 0 1 0 0 0 1 0 1 1 0 0 0 0 0 1 0\n", + " 1 1 0 0 0 0 1 1 0 1 0 0 0 0 1 0 1 1 0 1 1 1 0 1 1 0 0 0 0 1 1 1 0 1 0 1 1\n", + " 0 0 0 0 1 1 1 1 1 0 0 1 0 1 0 1 1 0 1 1 1 1 1 1 1 1 0 0 1 1 1 1 1 1 1 1 1\n", + " 1 0 1 1 1 0 1 0 0 0 0 1 0 0 1 1 1 0 1 1 1 1 0 0 0 0 0 1 1 1 0 0 1 0 1 0 0\n", + " 1 0 0 1 0 1 0 1 1 1 0 1 1 0 1 1 0 0 1 0 0 1 1 1 1 0 1 1 0 1 1 1 0 1 1 0 1\n", + " 0 0 1 1 1 1 0 1 0 0 1 1 1 0 1 1 1 1 0 0 1 0 1 0 0 0 1 1 0 1 0 0 1 0 1 0 0\n", + " 1 0 1 1 0 1 1 1 1 0 0 1 0 0 1 1 1 1 0 0 1 1 1 1 0 0 1 0 0 1 0 1 0 1 1 1 0\n", + " 1 0 1 1 0 0 0 0 1 0 0 0 1 1 1 1 0 0 0 1 1 0 1 1 0 0 1 0 0 1 1 1 1 0 0 0 1\n", + " 0 1 0 1 1 0 1 0 0 0 1 1 1 1 1 1 0 1 0 1 1 0 1 0 0 1 1 1 1 0 1 1 0 1 1 1 1\n", + " 1 1 0 0 0 0 0 1 0 1 0 0 0 1 0 1 0 0 1 0 1 1 0 1 0 1 1 1 0 1 1 1 1 0 0 1 1\n", + " 1 1 1 0 1 1 0 1] (896,) 0.12653340353855683 tensor(0.1265, dtype=torch.float64)\n", + "why [False False False ... False False False] (1152,) [0 0 0 ... 1 1 0] (1152,) 0.15009138463477656 tensor(0.1501, dtype=torch.float64)\n", + "why [ True True True ... True True True] (1920,) [0 0 1 ... 1 0 0] (1920,) 0.13893378955027236 tensor(0.1389, dtype=torch.float64)\n" ] }, { - "ename": "KeyboardInterrupt", - "evalue": "", + "ename": "RuntimeError", + "evalue": "All input tensors must be on the same device. Received cpu and cuda:0", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0mbest_val_metric\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m train(\n\u001b[0m\u001b[1;32m 26\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdatasets\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/wilds/examples/train.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(algorithm, datasets, general_logger, config, epoch_offset, best_val_metric)\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0;31m# First run training\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 90\u001b[0;31m \u001b[0mrun_epoch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgeneral_logger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 91\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;31m# Then run val\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/wilds/examples/train.py\u001b[0m in \u001b[0;36mrun_epoch\u001b[0;34m(algorithm, dataset, general_logger, epoch, config, train)\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miterator\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 43\u001b[0;31m \u001b[0mbatch_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 44\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0mbatch_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/wilds/examples/algorithms/single_model_algorithm.py\u001b[0m in \u001b[0;36mupdate\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 103\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_update\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[0;31m# log results\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 105\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_log\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 106\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msanitize_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/examples/algorithms/group_algorithm.py\u001b[0m in \u001b[0;36mupdate_log\u001b[0;34m(self, results)\u001b[0m\n\u001b[1;32m 54\u001b[0m return_dict=False)\n\u001b[1;32m 55\u001b[0m \u001b[0mbatch_log\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34mf'{self.group_prefix}{m.name}'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgroup_metrics\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 56\u001b[0;31m batch_log[m.agg_metric_field] = m.compute(\n\u001b[0m\u001b[1;32m 57\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'y_pred'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 58\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'y_true'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/wilds/common/metrics/metric.py\u001b[0m in \u001b[0;36mcompute\u001b[0;34m(self, y_pred, y_true, return_dict)\u001b[0m\n\u001b[1;32m 86\u001b[0m \u001b[0magg_metric\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0.\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0my_true\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 87\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 88\u001b[0;31m \u001b[0magg_metric\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 89\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 90\u001b[0m results = {\n", - "\u001b[0;32m~/wilds/wilds/common/metrics/metric.py\u001b[0m in \u001b[0;36m_compute\u001b[0;34m(self, y_pred, y_true)\u001b[0m\n\u001b[1;32m 224\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 225\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_compute\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 226\u001b[0;31m \u001b[0mflattened_metrics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_flattened\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 227\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mflattened_metrics\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m==\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 228\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0.\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0my_true\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/wilds/common/metrics/metric.py\u001b[0m in \u001b[0;36mcompute_flattened\u001b[0;34m(self, y_pred, y_true, return_dict)\u001b[0m\n\u001b[1;32m 240\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcompute_flattened\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 241\u001b[0m \u001b[0mis_labeled\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m~\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misnan\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_true\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 242\u001b[0;31m \u001b[0mbatch_idx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwhere\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mis_labeled\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 243\u001b[0m \u001b[0mflattened_y_pred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mis_labeled\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 244\u001b[0m \u001b[0mflattened_y_true\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mis_labeled\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + "\u001b[0;32m~/wilds/examples/algorithms/group_algorithm.py\u001b[0m in \u001b[0;36mupdate_log\u001b[0;34m(self, results)\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mm\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogged_metrics\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_group_logging\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m group_metrics, group_counts, worst_group_metric = m.compute_group_wise(\n\u001b[0m\u001b[1;32m 50\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'y_pred'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'y_true'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/wilds/common/metrics/metric.py\u001b[0m in \u001b[0;36mcompute_group_wise\u001b[0;34m(self, y_pred, y_true, g, n_groups, return_dict)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mDictionary\u001b[0m \u001b[0mof\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \"\"\"\n\u001b[0;32m--> 114\u001b[0;31m \u001b[0mgroup_metrics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgroup_counts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mworst_group_metric\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compute_group_wise\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_groups\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 115\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/wilds/common/metrics/metric.py\u001b[0m in \u001b[0;36m_compute_group_wise\u001b[0;34m(self, y_pred, y_true, g, n_groups)\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mg\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mgroup_idx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 135\u001b[0m y_true[g == group_idx]))\n\u001b[0;32m--> 136\u001b[0;31m \u001b[0mgroup_metrics\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgroup_metrics\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 137\u001b[0m \u001b[0mworst_group_metric\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mworst\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgroup_metrics\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mgroup_counts\u001b[0m\u001b[0;34m>\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mRuntimeError\u001b[0m: All input tensors must be on the same device. Received cpu and cuda:0" ] } ], diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py index 01c81111..cc87ee5f 100644 --- a/wilds/common/metrics/all_metrics.py +++ b/wilds/common/metrics/all_metrics.py @@ -81,6 +81,9 @@ def _compute_flattened(self, flattened_y_pred, flattened_y_true): to_ret = torch.tensor(score).to(flattened_y_pred.device) print("why ", ytr, ytr.shape, ypr, ypr.shape, score, to_ret) return to_ret + + def _compute(self, y_pred, y_true): + return self._compute_flattened(y_pred, y_true) def worst(self, metrics): return minimum(metrics) @@ -127,6 +130,33 @@ def _compute(self, y_pred, y_true): def worst(self, metrics): return minimum(metrics) +class MTAveragePrecision(Metric): + def __init__(self, prediction_fn=logits_to_binary_pred, name=None, average='macro'): + self.prediction_fn = prediction_fn + if name is None: + name = f'avgprec' + if average is not None: + name+=f'-{average}' + self.average = average + super().__init__(name=name) + + def _compute(self, y_pred, y_true): + if self.prediction_fn is not None: + y_pred = self.prediction_fn(y_pred) + ytr = np.array(torch.flatten(y_true.squeeze()).detach().cpu().numpy() > 0) + ypr = torch.flatten(y_pred.squeeze()).detach().cpu().numpy() + score = sklearn.metrics.average_precision_score( + ytr, + ypr, + average=self.average + ) + to_ret = torch.tensor(score)#.to(flattened_y_pred.device) + print("why ", ytr, ytr.shape, ypr, ypr.shape, score, to_ret) + return to_ret + + def worst(self, metrics): + return minimum(metrics) + class F1(Metric): def __init__(self, prediction_fn=None, name=None, average='binary'): self.prediction_fn = prediction_fn From 6f9ffeba3f2de74274c5b1f738444542acdaeb1c Mon Sep 17 00:00:00 2001 From: aikanor Date: Fri, 19 Mar 2021 20:34:39 -0700 Subject: [PATCH 078/244] integration besides eval 3/ --- examples/sbox_run_expt.ipynb | 1117 ++++++++++++++++++++++++++- wilds/common/metrics/all_metrics.py | 3 +- 2 files changed, 1097 insertions(+), 23 deletions(-) diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb index 2aad102b..4eeaee7a 100644 --- a/examples/sbox_run_expt.ipynb +++ b/examples/sbox_run_expt.ipynb @@ -106,7 +106,7 @@ { "data": { "text/plain": [ - "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" + "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" ] }, "execution_count": 2, @@ -325,9 +325,9 @@ "Progress bar: False\n", "Resume: False\n", "\n", - "chr3 2.979121685028076\n", - "chr2 6.626891374588013\n", - "chr1 10.355815410614014\n" + "chr3 3.016324281692505\n", + "chr2 6.676640510559082\n", + "chr1 10.41373872756958\n" ] } ], @@ -766,7 +766,13 @@ { "data": { "text/plain": [ - "torch.Size([64, 128])" + "array([[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", + " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", + " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", + " ...,\n", + " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", + " [0. , 0. , 0. , ..., 0.5, 0.5, 0.5],\n", + " [0. , 0. , 0. , ..., 0.5, 0.5, 1. ]], dtype=float32)" ] }, "execution_count": 10, @@ -776,7 +782,7 @@ ], "source": [ "#np.unique(full_dataset._metadata_df['split'], return_counts=True)\n", - "y_true.squeeze().shape" + "y_true.squeeze().detach().numpy()" ] }, { @@ -809,7 +815,7 @@ { "data": { "text/plain": [ - "Namespace(algo_log_metric='multitask_avgprec', algorithm='ERM', batch_size=64, coral_penalty_weight=None, dataset='encode-tfbs', dataset_kwargs={}, device=device(type='cuda', index=0), distinct_groups=None, download=False, eval_epoch=None, eval_loader='standard', eval_only=False, eval_splits=[], eval_transform=None, evaluate_all_splits=True, frac=1.0, group_dro_step_size=None, groupby_fields=['celltype'], irm_lambda=None, irm_penalty_anneal_iters=None, loader_kwargs={'num_workers': 1, 'pin_memory': True}, log_dir='./logs', log_every=50, loss_function='multitask_bce', lr=0.001, max_grad_norm=None, max_token_length=None, model='leopard', model_kwargs={'pretrained': False}, n_epochs=5, n_groups_per_batch=2, no_group_logging=False, optimizer='Adam', optimizer_kwargs={}, progress_bar=False, resize_scale=None, resume=False, root_dir='data', save_best=True, save_last=True, save_step=None, scheduler=None, scheduler_kwargs={}, scheduler_metric_name=None, scheduler_metric_split='val', seed=0, split_scheme='official', target_resolution=None, train_loader='standard', train_transform=None, uniform_over_groups=False, use_wandb=False, val_metric='acc_avg', val_metric_decreasing=False, weight_decay=0.01)" + "device(type='cpu')" ] }, "execution_count": 11, @@ -818,7 +824,7 @@ } ], "source": [ - "config" + "y_true.device" ] }, { @@ -842,8 +848,8 @@ "\n", "Train:\n", "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (2432,) [1 0 1 ... 1 1 0] (2432,) 0.09923777357272781 tensor(0.0992, dtype=torch.float64)\n", - "why [False False False ... False False False] (1792,) [1 1 0 ... 1 0 1] (1792,) 0.18020602071676678 tensor(0.1802, dtype=torch.float64)\n", + "why [False False False ... False False False] (2432,) [1 0 1 ... 1 1 0] (2432,) 0.09923777357272781 tensor(0.0992, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1792,) [1 1 0 ... 1 0 1] (1792,) 0.18020602071676678 tensor(0.1802, device='cuda:0', dtype=torch.float64)\n", "why [False False False False False False False False False False False False\n", " False False False False False False False False False False False False\n", " False False False False False False False False False False False True\n", @@ -942,26 +948,1095 @@ " 1 0 1 1 0 0 0 0 1 0 0 0 1 1 1 1 0 0 0 1 1 0 1 1 0 0 1 0 0 1 1 1 1 0 0 0 1\n", " 0 1 0 1 1 0 1 0 0 0 1 1 1 1 1 1 0 1 0 1 1 0 1 0 0 1 1 1 1 0 1 1 0 1 1 1 1\n", " 1 1 0 0 0 0 0 1 0 1 0 0 0 1 0 1 0 0 1 0 1 1 0 1 0 1 1 1 0 1 1 1 1 0 0 1 1\n", - " 1 1 1 0 1 1 0 1] (896,) 0.12653340353855683 tensor(0.1265, dtype=torch.float64)\n", - "why [False False False ... False False False] (1152,) [0 0 0 ... 1 1 0] (1152,) 0.15009138463477656 tensor(0.1501, dtype=torch.float64)\n", - "why [ True True True ... True True True] (1920,) [0 0 1 ... 1 0 0] (1920,) 0.13893378955027236 tensor(0.1389, dtype=torch.float64)\n" + " 1 1 1 0 1 1 0 1] (896,) 0.12653340353855683 tensor(0.1265, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1152,) [0 0 0 ... 1 1 0] (1152,) 0.15009138463477656 tensor(0.1501, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... True True True] (1920,) [0 0 1 ... 1 0 0] (1920,) 0.13893378955027236 tensor(0.1389, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [1 0 1 ... 1 1 0] (8192,) 0.13583524260280033 tensor(0.1358, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... True True True] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.07954545454545454 tensor(0.0795, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.11778846153846154 tensor(0.1178, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.08138020833333333 tensor(0.0814, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.196875 tensor(0.1969, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.1623263888888889 tensor(0.1623, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1292724609375 tensor(0.1293, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.09678819444444445 tensor(0.0968, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.19587053571428573 tensor(0.1959, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.1015625 tensor(0.1016, device='cuda:0', dtype=torch.float64)\n", + "why [False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False True True True\n", + " True True True True True True True True True True True False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False True True\n", + " True True True True False False False False False False True True\n", + " True True True True True True True True True True True True\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False True True True True True True True\n", + " True True True True True True False False False False False False\n", + " False False False False True True True True True True True True\n", + " True True True True True True False False False False False False\n", + " False True True True True True True True True True True True\n", + " True True True True True True True False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False] (512,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] (512,) 0.154296875 tensor(0.1543, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.142578125 tensor(0.1426, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1318359375 tensor(0.1318, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (2432,) [0 0 0 ... 0 0 0] (2432,) 0.09580592105263158 tensor(0.0958, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.2506510416666667 tensor(0.2507, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.09809027777777778 tensor(0.0981, device='cuda:0', dtype=torch.float64)\n", + "why [ True False False ... False False False] (1024,) [0 0 0 ... 0 0 0] (1024,) 0.2353515625 tensor(0.2354, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.109375 tensor(0.1094, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.14599609375 tensor(0.1460, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... True True True] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.10107421875 tensor(0.1011, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.20454545454545456 tensor(0.2045, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1024,) [0 0 0 ... 0 0 0] (1024,) 0.099609375 tensor(0.0996, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.19810267857142858 tensor(0.1981, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.10885416666666667 tensor(0.1089, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1417236328125 tensor(0.1417, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.11221590909090909 tensor(0.1122, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.13040865384615385 tensor(0.1304, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... True True True] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.09588068181818182 tensor(0.0959, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.16471354166666666 tensor(0.1647, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.14935661764705882 tensor(0.1494, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1328125 tensor(0.1328, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.10997596153846154 tensor(0.1100, device='cuda:0', dtype=torch.float64)\n", + "why [ True False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.1232638888888889 tensor(0.1233, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.13616071428571427 tensor(0.1362, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.184375 tensor(0.1844, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.14149305555555555 tensor(0.1415, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1380615234375 tensor(0.1381, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.1 tensor(0.1000, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.15980113636363635 tensor(0.1598, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.11160714285714286 tensor(0.1116, device='cuda:0', dtype=torch.float64)\n", + "why [False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False True True True True True True True True True\n", + " True True True True True True True True True True True True\n", + " True True True True False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " True True True True True True True True True True True True\n", + " True False True True True True True True True True True True\n", + " True True True True False False False False False False True True\n", + " True True True True True True True True True True True True\n", + " True True True True False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False True True True\n", + " True True True True True True True True True True True False\n", + " False False False False True True True True True True True True\n", + " True True True True True True True True True True False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False True\n", + " True True True True True True True True True True True True\n", + " True False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " True True True False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False True True True True True True True True True True True\n", + " True True True False False True True True True True True True\n", + " True True True True True True True True True True True True\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False True True True True True\n", + " True True True True True True True True True False False False\n", + " False False False False False False False False False False False False] (768,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] (768,) 0.21614583333333334 tensor(0.2161, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2944,) [0 0 0 ... 0 0 0] (2944,) 0.1328125 tensor(0.1328, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.135498046875 tensor(0.1355, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.13385416666666666 tensor(0.1339, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.15178571428571427 tensor(0.1518, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.1203125 tensor(0.1203, device='cuda:0', dtype=torch.float64)\n", + "why [False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False True True True True True True True True\n", + " True True True True True True False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False True True True True True True True True True\n", + " True True True True True False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False True True True True\n", + " False False True True True True True True True True False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False True True True True True True True True\n", + " True True True True True True False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " True True True True True True True True True True True True\n", + " True True False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " True True True True True True True True True True True True\n", + " True False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False] (896,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0] (896,) 0.09040178571428571 tensor(0.0904, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.1045673076923077 tensor(0.1046, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1239013671875 tensor(0.1239, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.13051470588235295 tensor(0.1305, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.14609375 tensor(0.1461, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.12239583333333333 tensor(0.1224, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.14118303571428573 tensor(0.1412, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.16193181818181818 tensor(0.1619, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.13916015625 tensor(0.1392, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False True True] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.09801136363636363 tensor(0.0980, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (2816,) [0 0 0 ... 0 0 0] (2816,) 0.10404829545454546 tensor(0.1040, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.0875 tensor(0.0875, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.20099431818181818 tensor(0.2010, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.13984375 tensor(0.1398, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False True True] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1226806640625 tensor(0.1227, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.11145833333333334 tensor(0.1115, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.181640625 tensor(0.1816, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.11484375 tensor(0.1148, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.1658653846153846 tensor(0.1659, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.16685267857142858 tensor(0.1669, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1483154296875 tensor(0.1483, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.1302568958818959 tensor(0.1303, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.189453125 tensor(0.1895, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.11067708333333333 tensor(0.1107, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.23468815928270043 tensor(0.2347, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.13385416666666666 tensor(0.1339, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.16484484726123597 tensor(0.1648, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.11263020833333333 tensor(0.1126, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.15576171875 tensor(0.1558, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.11610949612403101 tensor(0.1161, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.17734375 tensor(0.1773, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.171875 tensor(0.1719, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1438577872555272 tensor(0.1439, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... True True True] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.12439903846153846 tensor(0.1244, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.22088068181818182 tensor(0.2209, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.15223817567567566 tensor(0.1522, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.1356534090909091 tensor(0.1357, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.15503202814868278 tensor(0.1550, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.15613628135565832 tensor(0.1561, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.08984375 tensor(0.0898, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.14543269230769232 tensor(0.1454, device='cuda:0', dtype=torch.float64)\n", + "why [False True True ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.1545138888888889 tensor(0.1545, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.16601859327507598 tensor(0.1660, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.09486607142857142 tensor(0.0949, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1307907754109508 tensor(0.1308, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [ True True True True True True True True True True True True\n", + " True True True False True True True True True True True True\n", + " True True True False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False True True True\n", + " True True True True True True True True True False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False True True True True True True\n", + " False False False True True True True True True False False False\n", + " False False False False False False False False False False True True\n", + " True True True True True True True True True True False False\n", + " False False False False False False False False False False False False\n", + " False True True True True True True True True True True True\n", + " True True True True False False False False True True True True\n", + " True True True True True True True False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " True True True True True True True True True True True True\n", + " True True False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False True True True True\n", + " True True True True True True True False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False True True True True True True True\n", + " True True True True False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False True True True True True True True\n", + " True True True True True False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False True True True True True True True True True True True\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " True True True True True True True True True True True False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False True True True True True True\n", + " True True True True True False False False False False False False\n", + " False False False False False False False False] (896,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0] (896,) 0.18861607142857142 tensor(0.1886, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2560,) [0 0 0 ... 0 0 0] (2560,) 0.2031711368110236 tensor(0.2032, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.12560096153846154 tensor(0.1256, device='cuda:0', dtype=torch.float64)\n", + "why [False False False False False False True True True True True True\n", + " True True True True True True True True False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False True True True True True True True True\n", + " True True True True True False False False False False False False\n", + " False False False False False True True True True True True True\n", + " True True True True True True True False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False True True True True True\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False True True True True True True True True True\n", + " True True True True True True True True False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False True True True True True True True True True\n", + " True True True True False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False True True True True True True True True\n", + " True True True True True True False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False] (768,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] (768,) 0.1171875 tensor(0.1172, device='cuda:0', dtype=torch.float64)\n", + "why [False False True ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.14322916666666666 tensor(0.1432, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.16135470753972053 tensor(0.1614, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.13365334378265414 tensor(0.1337, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.14312537741545892 tensor(0.1431, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.1383054595896147 tensor(0.1383, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.1884765625 tensor(0.1885, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.0889423076923077 tensor(0.0889, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1392934035570018 tensor(0.1393, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False True] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.1148314123790117 tensor(0.1148, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.234375 tensor(0.2344, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.13518363161819538 tensor(0.1352, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.1484375 tensor(0.1484, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.20069918995700245 tensor(0.2007, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.16518310916225415 tensor(0.1652, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.14157774390243902 tensor(0.1416, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.16829982517482517 tensor(0.1683, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1536,) [0 0 0 ... 0 1 1] (1536,) 0.12203414351851852 tensor(0.1220, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.14015534682080924 tensor(0.1402, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.18266864778921865 tensor(0.1827, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.15107465864301803 tensor(0.1511, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.12805550230061352 tensor(0.1281, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.1798145077383275 tensor(0.1798, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1536,) [0 0 0 ... 1 0 0] (1536,) 0.14846865031897927 tensor(0.1485, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.2182291666666667 tensor(0.2182, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.17879293893129775 tensor(0.1788, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.17408185325186412 tensor(0.1741, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.13385180995475113 tensor(0.1339, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.1634497549019608 tensor(0.1634, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.2111472315436242 tensor(0.2111, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.17961774553571427 tensor(0.1796, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.19505408546397282 tensor(0.1951, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.17720760641838973 tensor(0.1772, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.17260322523480418 tensor(0.1726, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.19631456413210446 tensor(0.1963, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.16002286585365852 tensor(0.1600, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.15676843030872636 tensor(0.1568, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.19980746809032893 tensor(0.1998, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.17694871945488722 tensor(0.1769, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... True True False] (3200,) [0 0 0 ... 0 0 0] (3200,) 0.17646062940470833 tensor(0.1765, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.21987862976406533 tensor(0.2199, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1536,) [1 0 0 ... 0 0 0] (1536,) 0.22485079470618036 tensor(0.2249, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.18892249103942654 tensor(0.1889, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1024,) [0 0 0 ... 0 0 0] (1024,) 0.20539447623239437 tensor(0.2054, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... True True False] (8192,) [1 0 0 ... 0 0 0] (8192,) 0.1956759851363835 tensor(0.1957, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [ True True True ... False False False] (2560,) [1 1 1 ... 0 0 0] (2560,) 0.16270833333333334 tensor(0.1627, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1536,) [0 0 0 ... 1 0 0] (1536,) 0.28461934747103557 tensor(0.2846, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.2885416666666667 tensor(0.2885, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.24493883087633087 tensor(0.2449, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.1776813162682728 tensor(0.1777, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (8192,) [1 1 1 ... 0 0 0] (8192,) 0.22326946266948078 tensor(0.2233, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.19251085890430153 tensor(0.1925, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.22502709178398156 tensor(0.2250, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.22283878504672897 tensor(0.2228, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1024,) [0 0 0 ... 0 0 0] (1024,) 0.2044723429144385 tensor(0.2045, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.1890666335978836 tensor(0.1891, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1970471833881579 tensor(0.1970, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.17814201811043567 tensor(0.1781, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.22176106178589622 tensor(0.2218, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.15586979984301413 tensor(0.1559, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.19933712121212122 tensor(0.1993, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.27441314553990614 tensor(0.2744, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.20748284786370724 tensor(0.2075, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.1627858889528193 tensor(0.1628, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.3289409447955064 tensor(0.3289, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1408,) [0 1 0 ... 0 0 0] (1408,) 0.25750782574670666 tensor(0.2575, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.2380265050832091 tensor(0.2380, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.20203645462301223 tensor(0.2020, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.2304055108248235 tensor(0.2304, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [ True True True ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.18352952167414052 tensor(0.1835, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.391332129896404 tensor(0.3913, device='cuda:0', dtype=torch.float64)\n", + "why [False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False True\n", + " True True True True True True True True True True False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False True True True True True True True True\n", + " True True True True True True True True True True True True\n", + " True True True False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " True True True True True True True True True True True False\n", + " False True True True True True True True True True True True\n", + " True False True True True True True True True True True True\n", + " True False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False True True True True True True\n", + " True True True True True True False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False True True True True\n", + " True True True True True True True True False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False True True True True True True True True\n", + " True True False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " True True True True True True True True True True True True\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " True True True True True True True True True False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False] (896,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 1 1 1 1 1 1 1 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 1 1 1 0 1 1 0 0 0 0 0 1 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1\n", + " 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1\n", + " 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0] (896,) 0.43876971003366205 tensor(0.4388, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.26807482215447154 tensor(0.2681, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.2458394306739895 tensor(0.2458, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.2842815311314583 tensor(0.2843, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.24575731426692965 tensor(0.2458, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.276717519724741 tensor(0.2767, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.38168526600954644 tensor(0.3817, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True True True True True True True True True True\n", + " True True True True True True False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False True True True True True True\n", + " True True True True True True True True True True True False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False True True True True True\n", + " True True True True True True True True False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False True True True True True True True True True True\n", + " True True True True False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False True True True True True True True True\n", + " True True True True True True True True True True False False\n", + " False False False False False True True True True True True True\n", + " True True True True True True True True True True True False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False True True True\n", + " True True True True True True True True True True False False\n", + " False True True True True True True True True True True True\n", + " True True False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False True True\n", + " True True True True True True True True True True True True\n", + " True True True True True False False False False False False False\n", + " False False False False True True True True True True True True\n", + " True True True True False False True True True True True True\n", + " True True True True True True True True True True True True\n", + " True False False False False True True True True True True True\n", + " True True True True True True False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " True True True True True True True True True True True True\n", + " True True True True True True True True False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False True True True True True True True True\n", + " True True True True True True True True True True False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False] (896,) [0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 1 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 1\n", + " 1 1 1 1 1 1 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 1 0 1 1 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0] (896,) 0.3275530937683716 tensor(0.3276, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.24250047241118666 tensor(0.2425, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.27537596564595973 tensor(0.2754, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.339521139314602 tensor(0.3395, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.28316756119010217 tensor(0.2832, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1024,) [0 1 0 ... 0 0 0] (1024,) 0.30224860634648365 tensor(0.3022, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.23057474330872174 tensor(0.2306, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.22791799898259513 tensor(0.2279, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.27437629915291323 tensor(0.2744, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.21319969405140976 tensor(0.2132, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.3474399687036469 tensor(0.3474, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.19751082251082253 tensor(0.1975, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.3353790123844628 tensor(0.3354, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.24501893939393937 tensor(0.2450, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.2624466475767001 tensor(0.2624, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1408,) [1 1 1 ... 0 0 0] (1408,) 0.22450973341004987 tensor(0.2245, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.27231664754255114 tensor(0.2723, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1024,) [0 0 0 ... 0 0 0] (1024,) 0.3152901785714286 tensor(0.3153, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.35922695360195356 tensor(0.3592, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... True True True] (2432,) [0 0 0 ... 0 0 0] (2432,) 0.26736473289421736 tensor(0.2674, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [1 1 1 ... 0 0 0] (8192,) 0.28538833123099405 tensor(0.2854, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.245172509039775 tensor(0.2452, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.24340502699055327 tensor(0.2434, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.28707033026885964 tensor(0.2871, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1024,) [0 0 0 ... 0 0 0] (1024,) 0.2957705135233918 tensor(0.2958, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.2895262781476896 tensor(0.2895, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.2656280862586716 tensor(0.2656, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... True True False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.19391790985177615 tensor(0.1939, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1152,) [0 0 0 ... 1 1 1] (1152,) 0.39839248075956224 tensor(0.3984, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.3400271739130435 tensor(0.3400, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.26218694096601075 tensor(0.2622, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.25949223766281415 tensor(0.2595, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.2731843170244799 tensor(0.2732, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (2432,) [0 0 0 ... 0 0 0] (2432,) 0.23153263758670284 tensor(0.2315, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.3294548915822105 tensor(0.3295, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1024,) [0 0 0 ... 0 0 0] (1024,) 0.50768331438611 tensor(0.5077, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.17941607556456285 tensor(0.1794, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.4005733735380117 tensor(0.4006, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.32525391000796444 tensor(0.3253, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.28327316031926486 tensor(0.2833, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (2688,) [0 0 0 ... 0 0 0] (2688,) 0.2455340291329215 tensor(0.2455, device='cuda:0', dtype=torch.float64)\n", + "why [False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False True True True True True True True\n", + " True True True True False False False False False False False False\n", + " False False False False False False False False False False True True\n", + " True True True True True True True True True True True True\n", + " True True False False False False False False False False False False\n", + " False False True True True True True True True True True True\n", + " True True False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False True True True True False False False False\n", + " False False False False False False False False False False False False\n", + " True True True True True True True True True True True True\n", + " True True True True True True True True True True True True\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False True True True True True True True True True True True\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False True True True\n", + " True True True True True True True True False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False True True True True True True True True\n", + " True True True False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False True True True True True True True True True True\n", + " True True False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False] (896,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1\n", + " 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0 0\n", + " 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 1 1 1 1 1\n", + " 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0] (896,) 0.36439732142857145 tensor(0.3644, device='cuda:0', dtype=torch.float64)\n", + "why [False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False True True True True True False False False True\n", + " True True True True True True True True True True True True\n", + " True True True True True True True True True True True False\n", + " False False False False True True True True True True True True\n", + " True True True False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False True True True True True True True True\n", + " True True True True True True False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " True True True True True True True True True True True True\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False True True True True True True True True True True\n", + " True True True True False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False True True True True True True True\n", + " True True True True True True True False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False False False True True True True\n", + " True True True True True True True True True True True True\n", + " True True True True True True False False False False False False\n", + " False False False False False False False False False False False False\n", + " False False False False False False True True True True True True\n", + " True True True True True True True True False False False False\n", + " False False False False False False False False False False False False] (768,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1\n", + " 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", + " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] (768,) 0.36334134615384617 tensor(0.3633, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.36328125 tensor(0.3633, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.3073375105806347 tensor(0.3073, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.42102430988608963 tensor(0.4210, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.39475473771436803 tensor(0.3948, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.3678160635096611 tensor(0.3678, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.19351388184584178 tensor(0.1935, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.24591191813804175 tensor(0.2459, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.3096451568959731 tensor(0.3096, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.24075629195519133 tensor(0.2408, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.17249526515151514 tensor(0.1725, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.2863095238095238 tensor(0.2863, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.4012790080941676 tensor(0.4013, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.32623064828253506 tensor(0.3262, device='cuda:0', dtype=torch.float64)\n", + "why [ True True True ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.29373969403168476 tensor(0.2937, device='cuda:0', dtype=torch.float64)\n", + "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", + "why [False False False ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.3421500286608995 tensor(0.3422, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (1664,) [0 0 0 ... 1 1 0] (1664,) 0.22848216513818703 tensor(0.2285, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.21294610507246378 tensor(0.2129, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.4324312010246706 tensor(0.4324, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.42839099459862173 tensor(0.4284, device='cuda:0', dtype=torch.float64)\n", + "why [False False False ... True True True] (8192,) [0 0 0 ... 1 1 0] (8192,) 0.3411826173375903 tensor(0.3412, device='cuda:0', dtype=torch.float64)\n" ] }, { - "ename": "RuntimeError", - "evalue": "All input tensors must be on the same device. Received cpu and cuda:0", + "ename": "KeyboardInterrupt", + "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0mbest_val_metric\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m train(\n\u001b[0m\u001b[1;32m 26\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdatasets\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/wilds/examples/train.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(algorithm, datasets, general_logger, config, epoch_offset, best_val_metric)\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0;31m# First run training\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 90\u001b[0;31m \u001b[0mrun_epoch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgeneral_logger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 91\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;31m# Then run val\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m~/wilds/examples/train.py\u001b[0m in \u001b[0;36mrun_epoch\u001b[0;34m(algorithm, dataset, general_logger, epoch, config, train)\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miterator\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 43\u001b[0;31m \u001b[0mbatch_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 44\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0mbatch_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/examples/algorithms/single_model_algorithm.py\u001b[0m in \u001b[0;36mupdate\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 103\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_update\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[0;31m# log results\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 105\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_log\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 106\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msanitize_dict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/examples/algorithms/group_algorithm.py\u001b[0m in \u001b[0;36mupdate_log\u001b[0;34m(self, results)\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mm\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogged_metrics\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_group_logging\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m group_metrics, group_counts, worst_group_metric = m.compute_group_wise(\n\u001b[0m\u001b[1;32m 50\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'y_pred'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 51\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'y_true'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/wilds/common/metrics/metric.py\u001b[0m in \u001b[0;36mcompute_group_wise\u001b[0;34m(self, y_pred, y_true, g, n_groups, return_dict)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mDictionary\u001b[0m \u001b[0mof\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \"\"\"\n\u001b[0;32m--> 114\u001b[0;31m \u001b[0mgroup_metrics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgroup_counts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mworst_group_metric\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compute_group_wise\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_pred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_true\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_groups\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 115\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/wilds/common/metrics/metric.py\u001b[0m in \u001b[0;36m_compute_group_wise\u001b[0;34m(self, y_pred, y_true, g, n_groups)\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0my_pred\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mg\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mgroup_idx\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 135\u001b[0m y_true[g == group_idx]))\n\u001b[0;32m--> 136\u001b[0;31m \u001b[0mgroup_metrics\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgroup_metrics\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 137\u001b[0m \u001b[0mworst_group_metric\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mworst\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mgroup_metrics\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mgroup_counts\u001b[0m\u001b[0;34m>\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 138\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mRuntimeError\u001b[0m: All input tensors must be on the same device. Received cpu and cuda:0" + "\u001b[0;32m~/wilds/examples/algorithms/single_model_algorithm.py\u001b[0m in \u001b[0;36mupdate\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0;31m# process batch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 102\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 103\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_update\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 104\u001b[0m \u001b[0;31m# log results\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_log\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/wilds/examples/algorithms/single_model_algorithm.py\u001b[0m in \u001b[0;36m_update\u001b[0;34m(self, results)\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_grad_norm\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[0mclip_grad_norm_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_grad_norm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 122\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 123\u001b[0m self.step_schedulers(\n\u001b[1;32m 124\u001b[0m \u001b[0mis_epoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/torch/autograd/grad_mode.py\u001b[0m in \u001b[0;36mdecorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__class__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mcast\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/torch/optim/adam.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 106\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[0mbeta1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbeta2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgroup\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'betas'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 108\u001b[0;31m F.adam(params_with_grad,\n\u001b[0m\u001b[1;32m 109\u001b[0m \u001b[0mgrads\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 110\u001b[0m \u001b[0mexp_avgs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/torch/optim/functional.py\u001b[0m in \u001b[0;36madam\u001b[0;34m(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, beta1, beta2, lr, weight_decay, eps)\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0mdenom\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mmax_exp_avg_sq\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mmath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbias_correction2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 93\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 94\u001b[0;31m \u001b[0mdenom\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mexp_avg_sq\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mmath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbias_correction2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 95\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 96\u001b[0m \u001b[0mstep_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlr\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mbias_correction1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py index cc87ee5f..c9db2574 100644 --- a/wilds/common/metrics/all_metrics.py +++ b/wilds/common/metrics/all_metrics.py @@ -150,8 +150,7 @@ def _compute(self, y_pred, y_true): ypr, average=self.average ) - to_ret = torch.tensor(score)#.to(flattened_y_pred.device) - print("why ", ytr, ytr.shape, ypr, ypr.shape, score, to_ret) + to_ret = torch.tensor(score).to(y_pred.device) return to_ret def worst(self, metrics): From 9ee223c129a43d10d9c2b086086a28e644689275 Mon Sep 17 00:00:00 2001 From: aikanor Date: Sat, 20 Mar 2021 13:04:33 -0700 Subject: [PATCH 079/244] rebase --- examples/sbox_run_expt.ipynb | 20 ++++++++------------ wilds/common/metrics/all_metrics.py | 1 - wilds/common/metrics/metric.py | 2 +- 3 files changed, 9 insertions(+), 14 deletions(-) diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb index 4eeaee7a..c4a25cc4 100644 --- a/examples/sbox_run_expt.ipynb +++ b/examples/sbox_run_expt.ipynb @@ -62,18 +62,9 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 14, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:root:The WILDS package is out of date. Your version is 1.0.0, while the latest version is 1.1.0.\n", - "WARNING:root:The OGB package is out of date. Your version is 1.2.4, while the latest version is 1.3.0.\n" - ] - } - ], + "outputs": [], "source": [ "import os, csv, sys\n", "os.environ['CUDA_VISIBLE_DEVICES'] = '4'\n", @@ -837,7 +828,12 @@ { "cell_type": "code", "execution_count": 12, - "metadata": {}, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, "outputs": [ { "name": "stdout", diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py index c9db2574..18a47cf4 100644 --- a/wilds/common/metrics/all_metrics.py +++ b/wilds/common/metrics/all_metrics.py @@ -79,7 +79,6 @@ def _compute_flattened(self, flattened_y_pred, flattened_y_true): average=self.average ) to_ret = torch.tensor(score).to(flattened_y_pred.device) - print("why ", ytr, ytr.shape, ypr, ypr.shape, score, to_ret) return to_ret def _compute(self, y_pred, y_true): diff --git a/wilds/common/metrics/metric.py b/wilds/common/metrics/metric.py index 2e2b0c8f..3a628886 100644 --- a/wilds/common/metrics/metric.py +++ b/wilds/common/metrics/metric.py @@ -232,7 +232,7 @@ def _compute(self, y_pred, y_true): def _compute_group_wise(self, y_pred, y_true, g, n_groups): flattened_metrics, indices = self.compute_flattened(y_pred, y_true, return_dict=False) flattened_g = g[indices] - print(flattened_metrics.shape, flattened_g.shape, (indices > 0).sum(), y_pred.shape, y_true.shape) + # print(flattened_metrics.shape, flattened_g.shape, (indices > 0).sum(), y_pred.shape, y_true.shape) group_metrics, group_counts = avg_over_groups(flattened_metrics, flattened_g, n_groups) worst_group_metric = self.worst(group_metrics[group_counts>0]) return group_metrics, group_counts, worst_group_metric From 2087901130a2e377f31904e23a8f9df69282395a Mon Sep 17 00:00:00 2001 From: aikanor Date: Sat, 20 Mar 2021 17:56:37 -0700 Subject: [PATCH 080/244] remove staging nb --- examples/sbox_run_expt.ipynb | 2158 ---------------------------------- 1 file changed, 2158 deletions(-) delete mode 100644 examples/sbox_run_expt.ipynb diff --git a/examples/sbox_run_expt.ipynb b/examples/sbox_run_expt.ipynb deleted file mode 100644 index c4a25cc4..00000000 --- a/examples/sbox_run_expt.ipynb +++ /dev/null @@ -1,2158 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# run_expt.py contents\n", - "\n", - "## 1) Preamble" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "47.42578125\n" - ] - } - ], - "source": [ - "import os, psutil; print(psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2)" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true - } - }, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'bw' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# import pyBigWig\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;31m# %timeit bw = pyBigWig.open(\"/users/abalsubr/wilds/examples/data/encode-tfbs_v1.0/DNASE.K562.fc.signal.bigwig\")\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mget_ipython\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_line_magic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'timeit'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"bw.values('chr1', 10000, 22800, numpy=True)\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/IPython/core/interactiveshell.py\u001b[0m in \u001b[0;36mrun_line_magic\u001b[0;34m(self, magic_name, line, _stack_depth)\u001b[0m\n\u001b[1;32m 2334\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'local_ns'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_local_scope\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstack_depth\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2335\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbuiltin_trap\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2336\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2337\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2338\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m\u001b[0m in \u001b[0;36mtimeit\u001b[0;34m(self, line, cell, local_ns)\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/IPython/core/magic.py\u001b[0m in \u001b[0;36m\u001b[0;34m(f, *a, **k)\u001b[0m\n\u001b[1;32m 185\u001b[0m \u001b[0;31m# but it's overkill for just that one bit of state.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 186\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mmagic_deco\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 187\u001b[0;31m \u001b[0mcall\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 188\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 189\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcallable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/IPython/core/magics/execution.py\u001b[0m in \u001b[0;36mtimeit\u001b[0;34m(self, line, cell, local_ns)\u001b[0m\n\u001b[1;32m 1167\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mindex\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1168\u001b[0m \u001b[0mnumber\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m10\u001b[0m \u001b[0;34m**\u001b[0m \u001b[0mindex\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1169\u001b[0;31m \u001b[0mtime_number\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtimer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtimeit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnumber\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1170\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtime_number\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0;36m0.2\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1171\u001b[0m \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/IPython/core/magics/execution.py\u001b[0m in \u001b[0;36mtimeit\u001b[0;34m(self, number)\u001b[0m\n\u001b[1;32m 167\u001b[0m \u001b[0mgc\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdisable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 168\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 169\u001b[0;31m \u001b[0mtiming\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minner\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mit\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtimer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 170\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 171\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mgcold\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m\u001b[0m in \u001b[0;36minner\u001b[0;34m(_it, _timer)\u001b[0m\n", - "\u001b[0;31mNameError\u001b[0m: name 'bw' is not defined" - ] - } - ], - "source": [ - "# import pyBigWig\n", - "# %timeit bw = pyBigWig.open(\"/users/abalsubr/wilds/examples/data/encode-tfbs_v1.0/DNASE.K562.fc.signal.bigwig\")\n", - "%timeit bw.values('chr1', 10000, 22800, numpy=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], - "source": [ - "import os, csv, sys\n", - "os.environ['CUDA_VISIBLE_DEVICES'] = '4'\n", - "\n", - "import time\n", - "import argparse\n", - "import numpy as np, pandas as pd\n", - "import torch\n", - "import torch.nn as nn\n", - "import torchvision\n", - "import pyBigWig\n", - "from collections import defaultdict\n", - "\n", - "from wilds.common.data_loaders import get_train_loader, get_eval_loader\n", - "from wilds.common.grouper import CombinatorialGrouper\n", - "\n", - "from utils import set_seed, Logger, BatchLogger, log_config, ParseKwargs, load, initialize_wandb, log_group_data, parse_bool\n", - "from train import train, evaluate\n", - "from algorithms.initializer import initialize_algorithm\n", - "from transforms import initialize_transform\n", - "from configs.utils import populate_defaults\n", - "import configs.supported as supported" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "_StoreAction(option_strings=['--resume'], dest='resume', nargs='?', const=True, default=False, type=, choices=None, help=None, metavar=None)" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "''' set default hyperparams in default_hyperparams.py '''\n", - "parser = argparse.ArgumentParser()\n", - "\n", - "# Required arguments\n", - "parser.add_argument('-d', '--dataset', choices=supported.datasets, required=True)\n", - "parser.add_argument('--algorithm', required=True, choices=supported.algorithms)\n", - "parser.add_argument('--root_dir', required=True,\n", - " help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')\n", - "\n", - "# Dataset\n", - "parser.add_argument('--split_scheme', help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')\n", - "parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--download', default=False, type=parse_bool, const=True, nargs='?',\n", - " help='If true, tries to downloads the dataset if it does not exist in root_dir.')\n", - "parser.add_argument('--frac', type=float, default=1.0,\n", - " help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.')\n", - "\n", - "# Loaders\n", - "parser.add_argument('--loader_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--train_loader', choices=['standard', 'group'])\n", - "parser.add_argument('--uniform_over_groups', type=parse_bool, const=True, nargs='?')\n", - "parser.add_argument('--distinct_groups', type=parse_bool, const=True, nargs='?')\n", - "parser.add_argument('--n_groups_per_batch', type=int)\n", - "parser.add_argument('--batch_size', type=int)\n", - "parser.add_argument('--eval_loader', choices=['standard'], default='standard')\n", - "\n", - "# Model\n", - "parser.add_argument('--model', choices=supported.models)\n", - "parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},\n", - " help='keyword arguments for model initialization passed as key1=value1 key2=value2')\n", - "\n", - "# Transforms\n", - "parser.add_argument('--train_transform', choices=supported.transforms)\n", - "parser.add_argument('--eval_transform', choices=supported.transforms)\n", - "parser.add_argument('--target_resolution', nargs='+', type=int, help='target resolution. for example --target_resolution 224 224 for standard resnet.')\n", - "parser.add_argument('--resize_scale', type=float)\n", - "parser.add_argument('--max_token_length', type=int)\n", - "\n", - "# Objective\n", - "parser.add_argument('--loss_function', choices = supported.losses)\n", - "\n", - "# Algorithm\n", - "parser.add_argument('--groupby_fields', nargs='+')\n", - "parser.add_argument('--group_dro_step_size', type=float)\n", - "parser.add_argument('--coral_penalty_weight', type=float)\n", - "parser.add_argument('--irm_lambda', type=float)\n", - "parser.add_argument('--irm_penalty_anneal_iters', type=int)\n", - "parser.add_argument('--algo_log_metric')\n", - "\n", - "# Model selection\n", - "parser.add_argument('--val_metric')\n", - "parser.add_argument('--val_metric_decreasing', type=parse_bool, const=True, nargs='?')\n", - "\n", - "# Optimization\n", - "parser.add_argument('--n_epochs', type=int)\n", - "parser.add_argument('--optimizer', choices=supported.optimizers)\n", - "parser.add_argument('--lr', type=float)\n", - "parser.add_argument('--weight_decay', type=float)\n", - "parser.add_argument('--max_grad_norm', type=float)\n", - "parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "\n", - "# Scheduler\n", - "parser.add_argument('--scheduler', choices=supported.schedulers)\n", - "parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')\n", - "parser.add_argument('--scheduler_metric_name')\n", - "\n", - "# Evaluation\n", - "parser.add_argument('--evaluate_all_splits', type=parse_bool, const=True, nargs='?', default=True)\n", - "parser.add_argument('--eval_splits', nargs='+', default=[])\n", - "parser.add_argument('--eval_only', type=parse_bool, const=True, nargs='?', default=False)\n", - "parser.add_argument('--eval_epoch', default=None, type=int)\n", - "\n", - "# Misc\n", - "parser.add_argument('--device', type=int, default=0)\n", - "parser.add_argument('--seed', type=int, default=0)\n", - "parser.add_argument('--log_dir', default='./logs')\n", - "parser.add_argument('--log_every', default=50, type=int)\n", - "parser.add_argument('--save_step', type=int)\n", - "parser.add_argument('--save_best', type=parse_bool, const=True, nargs='?', default=True)\n", - "parser.add_argument('--save_last', type=parse_bool, const=True, nargs='?', default=True)\n", - "parser.add_argument('--no_group_logging', type=parse_bool, const=True, nargs='?')\n", - "parser.add_argument('--use_wandb', type=parse_bool, const=True, nargs='?', default=False)\n", - "parser.add_argument('--progress_bar', type=parse_bool, const=True, nargs='?', default=False)\n", - "parser.add_argument('--resume', type=parse_bool, const=True, nargs='?', default=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "argstr_camelyon = \"--dataset camelyon17 --algorithm ERM --root_dir data\"\n", - "config_camelyon = parser.parse_args(argstr_camelyon.split())\n", - "config_camelyon = populate_defaults(config_camelyon)\n", - "\n", - "argstr_bdd100k = \"--dataset bdd100k --algorithm ERM --root_dir data\"\n", - "config_bdd100k = parser.parse_args(argstr_bdd100k.split())\n", - "config_bdd100k = populate_defaults(config_bdd100k)\n", - "\n", - "argstr_encode = \"--dataset encode-tfbs --algorithm ERM --root_dir data\"\n", - "config_encode = parser.parse_args(argstr_encode.split())\n", - "config_encode = populate_defaults(config_encode)\n", - "\n", - "config = config_camelyon\n", - "config = config_encode\n", - "# config = config_bdd100k\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Namespace(algo_log_metric=None, algorithm='ERM', batch_size=None, coral_penalty_weight=None, dataset='encode-tfbs', dataset_kwargs={}, device=0, distinct_groups=None, download=False, eval_epoch=None, eval_loader='standard', eval_only=False, eval_splits=[], eval_transform=None, evaluate_all_splits=True, frac=1.0, group_dro_step_size=None, groupby_fields=None, irm_lambda=None, irm_penalty_anneal_iters=None, loader_kwargs={'num_workers': 1, 'pin_memory': True}, log_dir='./logs', log_every=50, loss_function=None, lr=None, max_grad_norm=None, max_token_length=None, model=None, model_kwargs={'pretrained': False}, n_epochs=None, n_groups_per_batch=None, no_group_logging=None, optimizer=None, optimizer_kwargs={'momentum': 0.9}, progress_bar=False, resize_scale=None, resume=False, root_dir='data', save_best=True, save_last=True, save_step=None, scheduler=None, scheduler_kwargs={}, scheduler_metric_name=None, scheduler_metric_split='val', seed=0, split_scheme=None, target_resolution=None, train_loader=None, train_transform=None, uniform_over_groups=None, use_wandb=False, val_metric=None, val_metric_decreasing=None, weight_decay=None)" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "argstr_camelyon = \"--dataset camelyon17 --algorithm ERM --root_dir data\"\n", - "# argstr_camelyon = \"--dataset civilcomments --algorithm ERM --root_dir data\"\n", - "config_camelyon = parser.parse_args(argstr_camelyon.split())\n", - "\n", - "argstr_encode = \"--dataset encode-tfbs --algorithm ERM --root_dir data\"\n", - "config_encode = parser.parse_args(argstr_encode.split())\n", - "config_encode" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "config.optimizer_kwargs = {}" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Dataset: encode-tfbs\n", - "Algorithm: ERM\n", - "Root dir: data\n", - "Split scheme: official\n", - "Dataset kwargs: {}\n", - "Download: False\n", - "Frac: 1.0\n", - "Loader kwargs: {'num_workers': 1, 'pin_memory': True}\n", - "Train loader: standard\n", - "Uniform over groups: False\n", - "Distinct groups: None\n", - "N groups per batch: 2\n", - "Batch size: 64\n", - "Eval loader: standard\n", - "Model: leopard\n", - "Model kwargs: {'pretrained': False}\n", - "Train transform: None\n", - "Eval transform: None\n", - "Target resolution: None\n", - "Resize scale: None\n", - "Max token length: None\n", - "Loss function: multitask_bce\n", - "Groupby fields: ['celltype']\n", - "Group dro step size: None\n", - "Coral penalty weight: None\n", - "Irm lambda: None\n", - "Irm penalty anneal iters: None\n", - "Algo log metric: multitask_avgprec\n", - "Val metric: acc_avg\n", - "Val metric decreasing: False\n", - "N epochs: 5\n", - "Optimizer: Adam\n", - "Lr: 0.001\n", - "Weight decay: 0.01\n", - "Max grad norm: None\n", - "Optimizer kwargs: {}\n", - "Scheduler: None\n", - "Scheduler kwargs: {}\n", - "Scheduler metric split: val\n", - "Scheduler metric name: None\n", - "Evaluate all splits: True\n", - "Eval splits: []\n", - "Eval only: False\n", - "Eval epoch: None\n", - "Device: cuda:0\n", - "Seed: 0\n", - "Log dir: ./logs\n", - "Log every: 50\n", - "Save step: None\n", - "Save best: True\n", - "Save last: True\n", - "No group logging: False\n", - "Use wandb: False\n", - "Progress bar: False\n", - "Resume: False\n", - "\n", - "chr3 3.016324281692505\n", - "chr2 6.676640510559082\n", - "chr1 10.41373872756958\n" - ] - } - ], - "source": [ - "# set device\n", - "config.device = torch.device(\"cuda:\" + str(config.device)) if torch.cuda.is_available() else torch.device(\"cpu\")\n", - "\n", - "## Initialize logs\n", - "if os.path.exists(config.log_dir) and config.resume:\n", - " resume=True\n", - " mode='a'\n", - "elif os.path.exists(config.log_dir) and config.eval_only:\n", - " resume=False\n", - " mode='a'\n", - "else:\n", - " resume=False\n", - " mode='w'\n", - "\n", - "if not os.path.exists(config.log_dir):\n", - " os.makedirs(config.log_dir)\n", - "logger = Logger(os.path.join(config.log_dir, 'log.txt'), mode)\n", - "\n", - "# Record config\n", - "log_config(config, logger)\n", - "\n", - "# Set random seed\n", - "set_seed(config.seed)\n", - "\n", - "# Data\n", - "full_dataset = supported.datasets[config.dataset](\n", - " root_dir=config.root_dir,\n", - " download=config.download,\n", - " split_scheme=config.split_scheme,\n", - " **config.dataset_kwargs)\n", - "\n", - "# To implement data augmentation (i.e., have different transforms\n", - "# at training time vs. test time), modify these two lines:\n", - "train_transform = initialize_transform(\n", - " transform_name=config.train_transform,\n", - " config=config,\n", - " dataset=full_dataset)\n", - "eval_transform = initialize_transform(\n", - " transform_name=config.eval_transform,\n", - " config=config,\n", - " dataset=full_dataset)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## 2) Initialize dataset object (trial version)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true, - "source_hidden": true - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "chr3 3.0039219856262207\n", - "chr4 5.89985990524292\n", - "chr5 8.640583038330078\n", - "chr6 11.237342596054077\n", - "chr7 13.666043519973755\n", - "chr10 15.858035326004028\n", - "chr12 17.94972252845764\n", - "chr13 19.689449071884155\n", - "chr14 21.30842876434326\n", - "chr15 22.856398582458496\n" - ] - }, - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0m_seq_bp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mchrom\u001b[0m \u001b[0;32min\u001b[0m \u001b[0m_all_chroms\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 58\u001b[0;31m \u001b[0m_seq_bp\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mchrom\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mseq_arr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mchrom\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 59\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mchrom\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mitime\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/numpy/lib/npyio.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 252\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmagic\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mformat\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mMAGIC_PREFIX\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 253\u001b[0m \u001b[0mbytes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzip\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 254\u001b[0;31m return format.read_array(bytes,\n\u001b[0m\u001b[1;32m 255\u001b[0m \u001b[0mallow_pickle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mallow_pickle\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 256\u001b[0m pickle_kwargs=self.pickle_kwargs)\n", - "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/numpy/lib/format.py\u001b[0m in \u001b[0;36mread_array\u001b[0;34m(fp, allow_pickle, pickle_kwargs)\u001b[0m\n\u001b[1;32m 773\u001b[0m \u001b[0mread_count\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmax_read_count\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcount\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 774\u001b[0m \u001b[0mread_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mread_count\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mdtype\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitemsize\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 775\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_read_bytes\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mread_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"array data\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 776\u001b[0m array[i:i+read_count] = numpy.frombuffer(data, dtype=dtype,\n\u001b[1;32m 777\u001b[0m count=read_count)\n", - "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/numpy/lib/format.py\u001b[0m in \u001b[0;36m_read_bytes\u001b[0;34m(fp, size, error_template)\u001b[0m\n\u001b[1;32m 902\u001b[0m \u001b[0;31m# done about that. note that regular files can't be non-blocking\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 903\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 904\u001b[0;31m \u001b[0mr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msize\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 905\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mr\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 906\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0msize\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/zipfile.py\u001b[0m in \u001b[0;36mread\u001b[0;34m(self, n)\u001b[0m\n\u001b[1;32m 938\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_offset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 939\u001b[0m \u001b[0;32mwhile\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_eof\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 940\u001b[0;31m \u001b[0mdata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_read1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 941\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mn\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 942\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_readbuffer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/zipfile.py\u001b[0m in \u001b[0;36m_read1\u001b[0;34m(self, n)\u001b[0m\n\u001b[1;32m 1028\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_left\u001b[0m \u001b[0;34m<=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1029\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_eof\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1030\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_update_crc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1031\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1032\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/zipfile.py\u001b[0m in \u001b[0;36m_update_crc\u001b[0;34m(self, newdata)\u001b[0m\n\u001b[1;32m 953\u001b[0m \u001b[0;31m# No need to compute the CRC if we don't have a reference value\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 954\u001b[0m \u001b[0;32mreturn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 955\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_running_crc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcrc32\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnewdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_running_crc\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 956\u001b[0m \u001b[0;31m# Check the CRC if we're at the end of the file\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 957\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_eof\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_running_crc\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_expected_crc\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " - ] - } - ], - "source": [ - "import os, time\n", - "import torch\n", - "import pandas as pd\n", - "import numpy as np\n", - "from wilds.datasets.wilds_dataset import WILDSDataset\n", - "from wilds.common.grouper import CombinatorialGrouper\n", - "from wilds.common.metrics.all_metrics import Accuracy\n", - "\n", - "root_dir='data'\n", - "download=False\n", - "split_scheme='official'\n", - "\n", - "itime = time.time()\n", - "_dataset_name = 'encode-tfbs'\n", - "_version = '1.0'\n", - "_download_url = 'https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/'\n", - "_data_dir = 'data/encode-tfbs_v1.0/'\n", - "_y_size = 1\n", - "_n_classes = 2\n", - "\n", - "_train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX']\n", - "_val_chroms = ['chr2', 'chr9', 'chr11']\n", - "_test_chroms = ['chr1', 'chr8', 'chr21']\n", - "_transcription_factor = 'MAX'\n", - "_train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", - "_val_celltype = ['A549']\n", - "_test_celltype = ['GM12878']\n", - "_all_chroms = _train_chroms + _val_chroms + _test_chroms\n", - "_all_celltypes = _train_celltypes + _val_celltype + _test_celltype\n", - "\n", - "_metadata_map = {}\n", - "_metadata_map['chr'] = _all_chroms\n", - "_metadata_map['celltype'] = _all_celltypes\n", - "\n", - "# Get the splits\n", - "if split_scheme=='official':\n", - " split_scheme = 'standard'\n", - "\n", - "_split_scheme = split_scheme\n", - "_split_dict = {\n", - " 'train': 0,\n", - " 'id_val': 1,\n", - " 'test': 2,\n", - " 'val': 3\n", - "}\n", - "_split_names = {\n", - " 'train': 'Train',\n", - " 'id_val': 'Validation (ID)',\n", - " 'test': 'Test',\n", - " 'val': 'Validation (OOD)',\n", - "}\n", - "\n", - "# Load sequence and DNase features\n", - "sequence_filename = os.path.join(_data_dir, 'sequence.npz')\n", - "seq_arr = np.load(sequence_filename)\n", - "_seq_bp = {}\n", - "for chrom in _all_chroms:\n", - " _seq_bp[chrom] = seq_arr[chrom]\n", - " print(chrom, time.time() - itime)\n", - "\n", - "_dnase_allcelltypes = {}\n", - "ct = 'avg'\n", - "dnase_avg_bw_path = os.path.join(_data_dir, 'Leopard_dnase/{}.bigwig'.format(ct))\n", - "_dnase_allcelltypes[ct] = pyBigWig.open(dnase_avg_bw_path)\n", - "for ct in _all_celltypes:\n", - " \"\"\"\n", - " dnase_filename = os.path.join(self._data_dir, '{}_dnase.npz'.format(ct))\n", - " dnase_npz_contents = np.load(dnase_filename)\n", - " self._dnase_allcelltypes[ct] = {}\n", - " for chrom in self._all_chroms: #self._seq_bp:\n", - " self._dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom]\n", - " \"\"\"\n", - " dnase_bw_path = os.path.join(_data_dir, 'Leopard_dnase/{}.bigwig'.format(ct))\n", - " _dnase_allcelltypes[ct] = pyBigWig.open(dnase_bw_path)\n", - " print(ct, time.time() - itime)\n", - "\n", - "_metadata_df = pd.read_csv(\n", - " _data_dir + 'labels/MAX/metadata_df.bed', sep='\\t', header=None, \n", - " index_col=None, names=['chr', 'start', 'stop', 'celltype']\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "jupyter": { - "source_hidden": true - } - }, - "outputs": [], - "source": [ - "train_regions_mask = np.isin(_metadata_df['chr'], _train_chroms)\n", - "val_regions_mask = np.isin(_metadata_df['chr'], _val_chroms)\n", - "test_regions_mask = np.isin(_metadata_df['chr'], _test_chroms)\n", - "train_celltype_mask = np.isin(_metadata_df['celltype'], _train_celltypes)\n", - "val_celltype_mask = np.isin(_metadata_df['celltype'], _val_celltype)\n", - "test_celltype_mask = np.isin(_metadata_df['celltype'], _test_celltype)\n", - "\n", - "split_array = -1*np.ones(_metadata_df.shape[0]).astype(int)\n", - "split_array[np.logical_and(train_regions_mask, train_celltype_mask)] = _split_dict['train']\n", - "split_array[np.logical_and(test_regions_mask, test_celltype_mask)] = _split_dict['test']\n", - "# Validate using validation chr, either using a designated validation cell line ('val') or a training cell line ('id_val')\n", - "split_array[np.logical_and(val_regions_mask, val_celltype_mask)] = _split_dict['val']\n", - "split_array[np.logical_and(val_regions_mask, train_celltype_mask)] = _split_dict['id_val']\n", - "\n", - "if _split_scheme=='standard':\n", - " _metadata_df.insert(len(_metadata_df.columns), 'split', split_array)\n", - "else:\n", - " raise ValueError(f'Split scheme {_split_scheme} not recognized')\n", - "\n", - "metadata_mask = (_metadata_df['split'] != -1)\n", - "_metadata_df = _metadata_df[_metadata_df['split'] != -1]\n", - "\n", - "chr_ints = _metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(_metadata_map['chr'])] )).values\n", - "celltype_ints = _metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(_metadata_map['celltype'])] )).values\n", - "_split_array = _metadata_df['split'].values\n", - "\n", - "_y_array = torch.Tensor(np.load(_data_dir + 'labels/MAX/metadata_y.npy'))\n", - "_y_array = _y_array[metadata_mask]\n", - "\n", - "_metadata_array = torch.stack(\n", - " (torch.LongTensor(chr_ints), \n", - " torch.LongTensor(celltype_ints)\n", - " ),\n", - " dim=1)\n", - "_metadata_fields = ['chr', 'celltype']" - ] - }, - { - "cell_type": "code", - "execution_count": 325, - "metadata": { - "jupyter": { - "source_hidden": true - } - }, - "outputs": [], - "source": [ - "def get_random_label_vec(\n", - " metadata_df, seed_chr, seed_celltype, seed_start, output_size=128\n", - "):\n", - " \"\"\"\n", - " Given a coordinate in a celltype, gets the labels of \n", - " the `output_size` 200bp bins from that coordinate onward. \n", - " \"\"\"\n", - " itime = time.time()\n", - " \n", - " # Extract regions from this chromosome in this celltype, to get a window of labels from\n", - " # print(time.time() - itime)\n", - " # chr_msk = np.array(metadata_df['chr']) == seed_region['chr']\n", - " # print(time.time() - itime)\n", - " # ct_msk = np.array(metadata_df['celltype']) == seed_region['celltype']\n", - " # mdf = metadata_df[chr_msk & ct_msk]\n", - " seq_size = output_size*50\n", - " mdf = metadata_df.loc[\n", - " (metadata_df['chr'] == seed_chr) & \n", - " (metadata_df['celltype'] == seed_celltype) & \n", - " (metadata_df['start'] >= seed_start) & \n", - " (metadata_df['stop'] < seed_start+seq_size)\n", - " ]\n", - " print(time.time() - itime)\n", - "\n", - " # Get labels\n", - " y_label_vec = np.zeros(output_size)\n", - " y_label_vec[(mdf['start'] - seed_start) // 50] = mdf['y']\n", - " return mdf, y_label_vec" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Initialize algorithm" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Train data...\n", - " celltype = H1-hESC: n = 5314\n", - " celltype = HCT116: n = 4759\n", - " celltype = HeLa-S3: n = 4635\n", - " celltype = HepG2: n = 4459\n", - " celltype = K562: n = 5169\n", - " celltype = A549: n = 0\n", - " celltype = GM12878: n = 0\n", - "Validation (ID) data...\n", - " celltype = H1-hESC: n = 6872\n", - " celltype = HCT116: n = 6315\n", - " celltype = HeLa-S3: n = 4219\n", - " celltype = HepG2: n = 8356\n", - " celltype = K562: n = 6538\n", - " celltype = A549: n = 0\n", - " celltype = GM12878: n = 0\n", - "Test data...\n", - " celltype = H1-hESC: n = 0\n", - " celltype = HCT116: n = 0\n", - " celltype = HeLa-S3: n = 0\n", - " celltype = HepG2: n = 0\n", - " celltype = K562: n = 0\n", - " celltype = A549: n = 0\n", - " celltype = GM12878: n = 4487\n", - "Validation (OOD) data...\n", - " celltype = H1-hESC: n = 0\n", - " celltype = HCT116: n = 0\n", - " celltype = HeLa-S3: n = 0\n", - " celltype = HepG2: n = 0\n", - " celltype = K562: n = 0\n", - " celltype = A549: n = 6728\n", - " celltype = GM12878: n = 0\n", - "Dout: 128\n" - ] - } - ], - "source": [ - "# config = config_encode\n", - "\n", - "train_grouper = CombinatorialGrouper(\n", - " dataset=full_dataset,\n", - " groupby_fields=config.groupby_fields)\n", - "\n", - "datasets = defaultdict(dict)\n", - "for split in full_dataset.split_dict.keys():\n", - " if split=='train':\n", - " transform = train_transform\n", - " verbose = True\n", - " elif split == 'val':\n", - " transform = eval_transform\n", - " verbose = True\n", - " else:\n", - " transform = eval_transform\n", - " verbose = False\n", - " # Get subset\n", - " datasets[split]['dataset'] = full_dataset.get_subset(\n", - " split,\n", - " frac=config.frac,\n", - " transform=transform)\n", - "\n", - " if split == 'train':\n", - " datasets[split]['loader'] = get_train_loader(\n", - " loader=config.train_loader,\n", - " dataset=datasets[split]['dataset'],\n", - " batch_size=config.batch_size,\n", - " uniform_over_groups=config.uniform_over_groups,\n", - " grouper=train_grouper,\n", - " distinct_groups=config.distinct_groups,\n", - " n_groups_per_batch=config.n_groups_per_batch,\n", - " **config.loader_kwargs)\n", - " else:\n", - " datasets[split]['loader'] = get_eval_loader(\n", - " loader=config.eval_loader,\n", - " dataset=datasets[split]['dataset'],\n", - " grouper=train_grouper,\n", - " batch_size=config.batch_size,\n", - " **config.loader_kwargs)\n", - "\n", - " # Set fields\n", - " datasets[split]['split'] = split\n", - " datasets[split]['name'] = full_dataset.split_names[split]\n", - " datasets[split]['verbose'] = verbose\n", - " # Loggers\n", - " # Loggers\n", - " datasets[split]['eval_logger'] = BatchLogger(\n", - " os.path.join(config.log_dir, f'{split}_eval.csv'), mode=mode, use_wandb=(config.use_wandb and verbose))\n", - " datasets[split]['algo_logger'] = BatchLogger(\n", - " os.path.join(config.log_dir, f'{split}_algo.csv'), mode=mode, use_wandb=(config.use_wandb and verbose))\n", - "\n", - " if config.use_wandb:\n", - " initialize_wandb(config)\n", - "\n", - "# Logging dataset info\n", - "if config.no_group_logging and full_dataset.is_classification and full_dataset.y_size==1:\n", - " log_grouper = CombinatorialGrouper(\n", - " dataset=full_dataset,\n", - " groupby_fields=['y'])\n", - "elif config.no_group_logging:\n", - " log_grouper = None\n", - "else:\n", - " log_grouper = train_grouper\n", - "log_group_data(datasets, log_grouper, logger)\n", - "\n", - "## Initialize algorithm\n", - "algorithm = initialize_algorithm(\n", - " config=config,\n", - " datasets=datasets,\n", - " train_grouper=train_grouper)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "for batch in datasets['train']['loader']:\n", - " x, y_true, metadata = batch\n", - " break\n", - "# x = torch.transpose(x, 1, 2)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor(0.7212, device='cuda:0', grad_fn=)" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "d = algorithm.process_batch(batch)\n", - "\n", - "a = algorithm.loss.compute(d['y_pred'], d['y_true'], return_dict=False)\n", - "a" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " ...,\n", - " [0. , 0. , 0. , ..., 0. , 0. , 0. ],\n", - " [0. , 0. , 0. , ..., 0.5, 0.5, 0.5],\n", - " [0. , 0. , 0. , ..., 0.5, 0.5, 1. ]], dtype=float32)" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "#np.unique(full_dataset._metadata_df['split'], return_counts=True)\n", - "y_true.squeeze().detach().numpy()" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'importlib' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m#import importlib\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mimportlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mNameError\u001b[0m: name 'importlib' is not defined" - ] - } - ], - "source": [ - "#import importlib\n", - "importlib.reload(train)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "device(type='cpu')" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "y_true.device" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Train" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Epoch [0]:\n", - "\n", - "Train:\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (2432,) [1 0 1 ... 1 1 0] (2432,) 0.09923777357272781 tensor(0.0992, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1792,) [1 1 0 ... 1 0 1] (1792,) 0.18020602071676678 tensor(0.1802, device='cuda:0', dtype=torch.float64)\n", - "why [False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False True\n", - " True True True True True True True True True True True False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False True\n", - " True True True True True True True True True True True False\n", - " False True True True True True True True True True True True\n", - " True True True True True False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False True True True True True True True True True\n", - " True True True False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False True True\n", - " True True True True True True True True True True False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False True True True True True True True True True\n", - " True True True False False False False False False False False False\n", - " False False False False True True True True True True True True\n", - " True True True False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False True True True True True True True\n", - " True True True True True False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False] (896,) [1 0 1 0 0 1 1 0 1 0 0 1 0 1 1 1 0 1 1 0 1 0 1 1 0 0 1 1 1 1 1 1 0 1 0 0 0\n", - " 1 0 1 1 0 0 0 1 1 1 0 1 0 1 0 0 0 1 0 0 0 0 0 0 0 0 0 1 0 0 1 0 0 1 0 0 0\n", - " 0 1 1 1 0 1 1 0 1 0 1 0 0 1 0 0 1 1 0 0 0 0 1 1 0 0 0 0 1 0 0 1 0 0 0 0 0\n", - " 1 1 0 0 1 1 0 1 0 1 0 0 1 0 1 0 1 1 0 0 0 1 1 1 1 1 0 1 0 0 0 1 0 0 1 1 0\n", - " 1 0 0 1 0 0 0 0 1 1 1 1 0 0 0 0 0 1 0 1 0 1 1 1 1 0 1 1 0 0 1 1 1 1 1 1 0\n", - " 0 0 1 1 1 1 1 1 1 1 0 0 1 0 1 0 0 0 1 1 1 0 1 1 1 1 1 1 0 1 0 1 1 0 1 1 0\n", - " 1 0 1 1 1 1 1 0 0 1 1 1 0 0 1 1 0 0 1 0 1 0 0 1 1 1 1 0 1 1 1 1 1 0 1 1 1\n", - " 1 0 0 0 1 0 0 1 0 1 1 1 0 1 1 1 0 0 1 1 0 0 0 1 0 0 1 1 0 1 0 0 0 1 0 0 0\n", - " 1 1 0 1 0 1 1 0 0 1 0 1 0 1 1 1 1 0 1 0 1 0 1 0 0 1 1 0 1 1 0 1 1 1 1 0 0\n", - " 1 1 0 1 0 0 1 1 1 0 0 0 1 0 0 0 0 0 1 0 0 1 1 0 0 0 0 0 1 1 0 1 1 1 1 0 1\n", - " 0 0 0 1 1 0 0 0 0 1 1 1 0 1 1 1 0 1 0 1 0 0 0 0 0 1 1 1 1 1 0 1 0 1 0 0 1\n", - " 1 0 0 1 1 0 0 0 1 1 1 1 0 1 1 0 1 1 1 0 0 1 0 1 0 0 1 1 0 0 0 0 0 0 1 0 0\n", - " 0 0 1 0 1 0 1 0 1 1 1 1 1 0 1 1 1 0 1 0 1 1 0 0 0 0 1 0 1 1 0 1 0 1 1 1 0\n", - " 0 1 0 1 1 1 0 0 1 0 0 1 0 1 1 0 1 0 1 0 0 1 0 1 1 0 1 0 1 0 0 1 1 1 0 1 0\n", - " 1 0 1 1 1 1 1 0 0 1 0 1 0 0 0 0 1 0 0 0 0 0 1 0 0 0 1 0 1 1 0 0 0 0 0 1 0\n", - " 1 1 0 0 0 0 1 1 0 1 0 0 0 0 1 0 1 1 0 1 1 1 0 1 1 0 0 0 0 1 1 1 0 1 0 1 1\n", - " 0 0 0 0 1 1 1 1 1 0 0 1 0 1 0 1 1 0 1 1 1 1 1 1 1 1 0 0 1 1 1 1 1 1 1 1 1\n", - " 1 0 1 1 1 0 1 0 0 0 0 1 0 0 1 1 1 0 1 1 1 1 0 0 0 0 0 1 1 1 0 0 1 0 1 0 0\n", - " 1 0 0 1 0 1 0 1 1 1 0 1 1 0 1 1 0 0 1 0 0 1 1 1 1 0 1 1 0 1 1 1 0 1 1 0 1\n", - " 0 0 1 1 1 1 0 1 0 0 1 1 1 0 1 1 1 1 0 0 1 0 1 0 0 0 1 1 0 1 0 0 1 0 1 0 0\n", - " 1 0 1 1 0 1 1 1 1 0 0 1 0 0 1 1 1 1 0 0 1 1 1 1 0 0 1 0 0 1 0 1 0 1 1 1 0\n", - " 1 0 1 1 0 0 0 0 1 0 0 0 1 1 1 1 0 0 0 1 1 0 1 1 0 0 1 0 0 1 1 1 1 0 0 0 1\n", - " 0 1 0 1 1 0 1 0 0 0 1 1 1 1 1 1 0 1 0 1 1 0 1 0 0 1 1 1 1 0 1 1 0 1 1 1 1\n", - " 1 1 0 0 0 0 0 1 0 1 0 0 0 1 0 1 0 0 1 0 1 1 0 1 0 1 1 1 0 1 1 1 1 0 0 1 1\n", - " 1 1 1 0 1 1 0 1] (896,) 0.12653340353855683 tensor(0.1265, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1152,) [0 0 0 ... 1 1 0] (1152,) 0.15009138463477656 tensor(0.1501, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... True True True] (1920,) [0 0 1 ... 1 0 0] (1920,) 0.13893378955027236 tensor(0.1389, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [1 0 1 ... 1 1 0] (8192,) 0.13583524260280033 tensor(0.1358, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... True True True] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.07954545454545454 tensor(0.0795, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.11778846153846154 tensor(0.1178, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.08138020833333333 tensor(0.0814, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.196875 tensor(0.1969, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.1623263888888889 tensor(0.1623, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1292724609375 tensor(0.1293, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.09678819444444445 tensor(0.0968, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.19587053571428573 tensor(0.1959, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.1015625 tensor(0.1016, device='cuda:0', dtype=torch.float64)\n", - "why [False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False True True True\n", - " True True True True True True True True True True True False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False True True\n", - " True True True True False False False False False False True True\n", - " True True True True True True True True True True True True\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False True True True True True True True\n", - " True True True True True True False False False False False False\n", - " False False False False True True True True True True True True\n", - " True True True True True True False False False False False False\n", - " False True True True True True True True True True True True\n", - " True True True True True True True False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False] (512,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] (512,) 0.154296875 tensor(0.1543, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.142578125 tensor(0.1426, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1318359375 tensor(0.1318, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (2432,) [0 0 0 ... 0 0 0] (2432,) 0.09580592105263158 tensor(0.0958, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.2506510416666667 tensor(0.2507, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.09809027777777778 tensor(0.0981, device='cuda:0', dtype=torch.float64)\n", - "why [ True False False ... False False False] (1024,) [0 0 0 ... 0 0 0] (1024,) 0.2353515625 tensor(0.2354, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.109375 tensor(0.1094, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.14599609375 tensor(0.1460, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... True True True] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.10107421875 tensor(0.1011, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.20454545454545456 tensor(0.2045, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1024,) [0 0 0 ... 0 0 0] (1024,) 0.099609375 tensor(0.0996, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.19810267857142858 tensor(0.1981, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.10885416666666667 tensor(0.1089, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1417236328125 tensor(0.1417, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.11221590909090909 tensor(0.1122, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.13040865384615385 tensor(0.1304, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... True True True] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.09588068181818182 tensor(0.0959, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.16471354166666666 tensor(0.1647, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.14935661764705882 tensor(0.1494, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1328125 tensor(0.1328, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.10997596153846154 tensor(0.1100, device='cuda:0', dtype=torch.float64)\n", - "why [ True False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.1232638888888889 tensor(0.1233, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.13616071428571427 tensor(0.1362, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.184375 tensor(0.1844, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.14149305555555555 tensor(0.1415, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1380615234375 tensor(0.1381, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.1 tensor(0.1000, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.15980113636363635 tensor(0.1598, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.11160714285714286 tensor(0.1116, device='cuda:0', dtype=torch.float64)\n", - "why [False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False True True True True True True True True True\n", - " True True True True True True True True True True True True\n", - " True True True True False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " True True True True True True True True True True True True\n", - " True False True True True True True True True True True True\n", - " True True True True False False False False False False True True\n", - " True True True True True True True True True True True True\n", - " True True True True False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False True True True\n", - " True True True True True True True True True True True False\n", - " False False False False True True True True True True True True\n", - " True True True True True True True True True True False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False True\n", - " True True True True True True True True True True True True\n", - " True False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " True True True False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False True True True True True True True True True True True\n", - " True True True False False True True True True True True True\n", - " True True True True True True True True True True True True\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False True True True True True\n", - " True True True True True True True True True False False False\n", - " False False False False False False False False False False False False] (768,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] (768,) 0.21614583333333334 tensor(0.2161, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2944,) [0 0 0 ... 0 0 0] (2944,) 0.1328125 tensor(0.1328, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.135498046875 tensor(0.1355, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.13385416666666666 tensor(0.1339, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.15178571428571427 tensor(0.1518, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.1203125 tensor(0.1203, device='cuda:0', dtype=torch.float64)\n", - "why [False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False True True True True True True True True\n", - " True True True True True True False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False True True True True True True True True True\n", - " True True True True True False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False True True True True\n", - " False False True True True True True True True True False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False True True True True True True True True\n", - " True True True True True True False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " True True True True True True True True True True True True\n", - " True True False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " True True True True True True True True True True True True\n", - " True False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False] (896,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0] (896,) 0.09040178571428571 tensor(0.0904, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.1045673076923077 tensor(0.1046, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1239013671875 tensor(0.1239, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.13051470588235295 tensor(0.1305, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.14609375 tensor(0.1461, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.12239583333333333 tensor(0.1224, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.14118303571428573 tensor(0.1412, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.16193181818181818 tensor(0.1619, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.13916015625 tensor(0.1392, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False True True] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.09801136363636363 tensor(0.0980, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (2816,) [0 0 0 ... 0 0 0] (2816,) 0.10404829545454546 tensor(0.1040, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.0875 tensor(0.0875, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.20099431818181818 tensor(0.2010, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.13984375 tensor(0.1398, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False True True] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1226806640625 tensor(0.1227, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.11145833333333334 tensor(0.1115, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.181640625 tensor(0.1816, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.11484375 tensor(0.1148, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.1658653846153846 tensor(0.1659, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.16685267857142858 tensor(0.1669, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1483154296875 tensor(0.1483, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.1302568958818959 tensor(0.1303, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.189453125 tensor(0.1895, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.11067708333333333 tensor(0.1107, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.23468815928270043 tensor(0.2347, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.13385416666666666 tensor(0.1339, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.16484484726123597 tensor(0.1648, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.11263020833333333 tensor(0.1126, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.15576171875 tensor(0.1558, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.11610949612403101 tensor(0.1161, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.17734375 tensor(0.1773, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.171875 tensor(0.1719, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1438577872555272 tensor(0.1439, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... True True True] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.12439903846153846 tensor(0.1244, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.22088068181818182 tensor(0.2209, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.15223817567567566 tensor(0.1522, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.1356534090909091 tensor(0.1357, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.15503202814868278 tensor(0.1550, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.15613628135565832 tensor(0.1561, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.08984375 tensor(0.0898, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.14543269230769232 tensor(0.1454, device='cuda:0', dtype=torch.float64)\n", - "why [False True True ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.1545138888888889 tensor(0.1545, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.16601859327507598 tensor(0.1660, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.09486607142857142 tensor(0.0949, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1307907754109508 tensor(0.1308, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [ True True True True True True True True True True True True\n", - " True True True False True True True True True True True True\n", - " True True True False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False True True True\n", - " True True True True True True True True True False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False True True True True True True\n", - " False False False True True True True True True False False False\n", - " False False False False False False False False False False True True\n", - " True True True True True True True True True True False False\n", - " False False False False False False False False False False False False\n", - " False True True True True True True True True True True True\n", - " True True True True False False False False True True True True\n", - " True True True True True True True False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " True True True True True True True True True True True True\n", - " True True False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False True True True True\n", - " True True True True True True True False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False True True True True True True True\n", - " True True True True False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False True True True True True True True\n", - " True True True True True False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False True True True True True True True True True True True\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " True True True True True True True True True True True False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False True True True True True True\n", - " True True True True True False False False False False False False\n", - " False False False False False False False False] (896,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0] (896,) 0.18861607142857142 tensor(0.1886, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2560,) [0 0 0 ... 0 0 0] (2560,) 0.2031711368110236 tensor(0.2032, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.12560096153846154 tensor(0.1256, device='cuda:0', dtype=torch.float64)\n", - "why [False False False False False False True True True True True True\n", - " True True True True True True True True False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False True True True True True True True True\n", - " True True True True True False False False False False False False\n", - " False False False False False True True True True True True True\n", - " True True True True True True True False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False True True True True True\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False True True True True True True True True True\n", - " True True True True True True True True False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False True True True True True True True True True\n", - " True True True True False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False True True True True True True True True\n", - " True True True True True True False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False] (768,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] (768,) 0.1171875 tensor(0.1172, device='cuda:0', dtype=torch.float64)\n", - "why [False False True ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.14322916666666666 tensor(0.1432, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.16135470753972053 tensor(0.1614, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.13365334378265414 tensor(0.1337, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.14312537741545892 tensor(0.1431, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.1383054595896147 tensor(0.1383, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.1884765625 tensor(0.1885, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.0889423076923077 tensor(0.0889, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1392934035570018 tensor(0.1393, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False True] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.1148314123790117 tensor(0.1148, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.234375 tensor(0.2344, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.13518363161819538 tensor(0.1352, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.1484375 tensor(0.1484, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.20069918995700245 tensor(0.2007, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.16518310916225415 tensor(0.1652, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.14157774390243902 tensor(0.1416, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.16829982517482517 tensor(0.1683, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1536,) [0 0 0 ... 0 1 1] (1536,) 0.12203414351851852 tensor(0.1220, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.14015534682080924 tensor(0.1402, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.18266864778921865 tensor(0.1827, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.15107465864301803 tensor(0.1511, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.12805550230061352 tensor(0.1281, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.1798145077383275 tensor(0.1798, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1536,) [0 0 0 ... 1 0 0] (1536,) 0.14846865031897927 tensor(0.1485, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.2182291666666667 tensor(0.2182, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.17879293893129775 tensor(0.1788, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.17408185325186412 tensor(0.1741, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.13385180995475113 tensor(0.1339, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.1634497549019608 tensor(0.1634, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.2111472315436242 tensor(0.2111, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.17961774553571427 tensor(0.1796, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.19505408546397282 tensor(0.1951, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.17720760641838973 tensor(0.1772, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.17260322523480418 tensor(0.1726, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.19631456413210446 tensor(0.1963, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.16002286585365852 tensor(0.1600, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.15676843030872636 tensor(0.1568, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.19980746809032893 tensor(0.1998, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.17694871945488722 tensor(0.1769, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... True True False] (3200,) [0 0 0 ... 0 0 0] (3200,) 0.17646062940470833 tensor(0.1765, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.21987862976406533 tensor(0.2199, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1536,) [1 0 0 ... 0 0 0] (1536,) 0.22485079470618036 tensor(0.2249, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.18892249103942654 tensor(0.1889, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1024,) [0 0 0 ... 0 0 0] (1024,) 0.20539447623239437 tensor(0.2054, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... True True False] (8192,) [1 0 0 ... 0 0 0] (8192,) 0.1956759851363835 tensor(0.1957, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [ True True True ... False False False] (2560,) [1 1 1 ... 0 0 0] (2560,) 0.16270833333333334 tensor(0.1627, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1536,) [0 0 0 ... 1 0 0] (1536,) 0.28461934747103557 tensor(0.2846, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.2885416666666667 tensor(0.2885, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.24493883087633087 tensor(0.2449, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.1776813162682728 tensor(0.1777, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (8192,) [1 1 1 ... 0 0 0] (8192,) 0.22326946266948078 tensor(0.2233, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.19251085890430153 tensor(0.1925, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.22502709178398156 tensor(0.2250, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.22283878504672897 tensor(0.2228, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1024,) [0 0 0 ... 0 0 0] (1024,) 0.2044723429144385 tensor(0.2045, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.1890666335978836 tensor(0.1891, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.1970471833881579 tensor(0.1970, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.17814201811043567 tensor(0.1781, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.22176106178589622 tensor(0.2218, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.15586979984301413 tensor(0.1559, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.19933712121212122 tensor(0.1993, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.27441314553990614 tensor(0.2744, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.20748284786370724 tensor(0.2075, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.1627858889528193 tensor(0.1628, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.3289409447955064 tensor(0.3289, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1408,) [0 1 0 ... 0 0 0] (1408,) 0.25750782574670666 tensor(0.2575, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.2380265050832091 tensor(0.2380, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.20203645462301223 tensor(0.2020, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.2304055108248235 tensor(0.2304, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [ True True True ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.18352952167414052 tensor(0.1835, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.391332129896404 tensor(0.3913, device='cuda:0', dtype=torch.float64)\n", - "why [False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False True\n", - " True True True True True True True True True True False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False True True True True True True True True\n", - " True True True True True True True True True True True True\n", - " True True True False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " True True True True True True True True True True True False\n", - " False True True True True True True True True True True True\n", - " True False True True True True True True True True True True\n", - " True False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False True True True True True True\n", - " True True True True True True False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False True True True True\n", - " True True True True True True True True False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False True True True True True True True True\n", - " True True False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " True True True True True True True True True True True True\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " True True True True True True True True True False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False] (896,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 1 1 1 1 1 1 1 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 1 1 1 0 1 1 0 0 0 0 0 1 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1\n", - " 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1\n", - " 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0] (896,) 0.43876971003366205 tensor(0.4388, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.26807482215447154 tensor(0.2681, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.2458394306739895 tensor(0.2458, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.2842815311314583 tensor(0.2843, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.24575731426692965 tensor(0.2458, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.276717519724741 tensor(0.2767, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2048,) [0 0 0 ... 0 0 0] (2048,) 0.38168526600954644 tensor(0.3817, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True True True True True True True True True True\n", - " True True True True True True False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False True True True True True True\n", - " True True True True True True True True True True True False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False True True True True True\n", - " True True True True True True True True False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False True True True True True True True True True True\n", - " True True True True False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False True True True True True True True True\n", - " True True True True True True True True True True False False\n", - " False False False False False True True True True True True True\n", - " True True True True True True True True True True True False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False True True True\n", - " True True True True True True True True True True False False\n", - " False True True True True True True True True True True True\n", - " True True False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False True True\n", - " True True True True True True True True True True True True\n", - " True True True True True False False False False False False False\n", - " False False False False True True True True True True True True\n", - " True True True True False False True True True True True True\n", - " True True True True True True True True True True True True\n", - " True False False False False True True True True True True True\n", - " True True True True True True False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " True True True True True True True True True True True True\n", - " True True True True True True True True False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False True True True True True True True True\n", - " True True True True True True True True True True False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False] (896,) [0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 1 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 1 1\n", - " 1 1 1 1 1 1 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 1 0 1 1 1 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0] (896,) 0.3275530937683716 tensor(0.3276, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.24250047241118666 tensor(0.2425, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.27537596564595973 tensor(0.2754, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.339521139314602 tensor(0.3395, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.28316756119010217 tensor(0.2832, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1024,) [0 1 0 ... 0 0 0] (1024,) 0.30224860634648365 tensor(0.3022, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.23057474330872174 tensor(0.2306, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.22791799898259513 tensor(0.2279, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.27437629915291323 tensor(0.2744, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.21319969405140976 tensor(0.2132, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.3474399687036469 tensor(0.3474, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.19751082251082253 tensor(0.1975, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1152,) [0 0 0 ... 0 0 0] (1152,) 0.3353790123844628 tensor(0.3354, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.24501893939393937 tensor(0.2450, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.2624466475767001 tensor(0.2624, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1408,) [1 1 1 ... 0 0 0] (1408,) 0.22450973341004987 tensor(0.2245, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.27231664754255114 tensor(0.2723, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1024,) [0 0 0 ... 0 0 0] (1024,) 0.3152901785714286 tensor(0.3153, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.35922695360195356 tensor(0.3592, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... True True True] (2432,) [0 0 0 ... 0 0 0] (2432,) 0.26736473289421736 tensor(0.2674, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [1 1 1 ... 0 0 0] (8192,) 0.28538833123099405 tensor(0.2854, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.245172509039775 tensor(0.2452, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.24340502699055327 tensor(0.2434, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.28707033026885964 tensor(0.2871, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1024,) [0 0 0 ... 0 0 0] (1024,) 0.2957705135233918 tensor(0.2958, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.2895262781476896 tensor(0.2895, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.2656280862586716 tensor(0.2656, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... True True False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.19391790985177615 tensor(0.1939, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1152,) [0 0 0 ... 1 1 1] (1152,) 0.39839248075956224 tensor(0.3984, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.3400271739130435 tensor(0.3400, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.26218694096601075 tensor(0.2622, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (2304,) [0 0 0 ... 0 0 0] (2304,) 0.25949223766281415 tensor(0.2595, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.2731843170244799 tensor(0.2732, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (2432,) [0 0 0 ... 0 0 0] (2432,) 0.23153263758670284 tensor(0.2315, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.3294548915822105 tensor(0.3295, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1024,) [0 0 0 ... 0 0 0] (1024,) 0.50768331438611 tensor(0.5077, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.17941607556456285 tensor(0.1794, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.4005733735380117 tensor(0.4006, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.32525391000796444 tensor(0.3253, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.28327316031926486 tensor(0.2833, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (2688,) [0 0 0 ... 0 0 0] (2688,) 0.2455340291329215 tensor(0.2455, device='cuda:0', dtype=torch.float64)\n", - "why [False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False True True True True True True True\n", - " True True True True False False False False False False False False\n", - " False False False False False False False False False False True True\n", - " True True True True True True True True True True True True\n", - " True True False False False False False False False False False False\n", - " False False True True True True True True True True True True\n", - " True True False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False True True True True False False False False\n", - " False False False False False False False False False False False False\n", - " True True True True True True True True True True True True\n", - " True True True True True True True True True True True True\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False True True True True True True True True True True True\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False True True True\n", - " True True True True True True True True False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False True True True True True True True True\n", - " True True True False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False True True True True True True True True True True\n", - " True True False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False] (896,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1\n", - " 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0 0\n", - " 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0 0 1 1 1 1 1\n", - " 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0] (896,) 0.36439732142857145 tensor(0.3644, device='cuda:0', dtype=torch.float64)\n", - "why [False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False True True True True True False False False True\n", - " True True True True True True True True True True True True\n", - " True True True True True True True True True True True False\n", - " False False False False True True True True True True True True\n", - " True True True False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False True True True True True True True True\n", - " True True True True True True False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " True True True True True True True True True True True True\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False True True True True True True True True True True\n", - " True True True True False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False True True True True True True True\n", - " True True True True True True True False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False False False True True True True\n", - " True True True True True True True True True True True True\n", - " True True True True True True False False False False False False\n", - " False False False False False False False False False False False False\n", - " False False False False False False True True True True True True\n", - " True True True True True True True True False False False False\n", - " False False False False False False False False False False False False] (768,) [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1\n", - " 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0\n", - " 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] (768,) 0.36334134615384617 tensor(0.3633, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.36328125 tensor(0.3633, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.3073375105806347 tensor(0.3073, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.42102430988608963 tensor(0.4210, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.39475473771436803 tensor(0.3948, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.3678160635096611 tensor(0.3678, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.19351388184584178 tensor(0.1935, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1792,) [0 0 0 ... 0 0 0] (1792,) 0.24591191813804175 tensor(0.2459, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.3096451568959731 tensor(0.3096, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.24075629195519133 tensor(0.2408, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1536,) [0 0 0 ... 0 0 0] (1536,) 0.17249526515151514 tensor(0.1725, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.2863095238095238 tensor(0.2863, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.4012790080941676 tensor(0.4013, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (1920,) [0 0 0 ... 0 0 0] (1920,) 0.32623064828253506 tensor(0.3262, device='cuda:0', dtype=torch.float64)\n", - "why [ True True True ... False False False] (8192,) [0 0 0 ... 0 0 0] (8192,) 0.29373969403168476 tensor(0.2937, device='cuda:0', dtype=torch.float64)\n", - "torch.Size([8192]) torch.Size([8192]) tensor(8064, device='cuda:0') torch.Size([64, 128]) torch.Size([64, 128])\n", - "why [False False False ... False False False] (2176,) [0 0 0 ... 0 0 0] (2176,) 0.3421500286608995 tensor(0.3422, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (1664,) [0 0 0 ... 1 1 0] (1664,) 0.22848216513818703 tensor(0.2285, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1280,) [0 0 0 ... 0 0 0] (1280,) 0.21294610507246378 tensor(0.2129, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1408,) [0 0 0 ... 0 0 0] (1408,) 0.4324312010246706 tensor(0.4324, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... False False False] (1664,) [0 0 0 ... 0 0 0] (1664,) 0.42839099459862173 tensor(0.4284, device='cuda:0', dtype=torch.float64)\n", - "why [False False False ... True True True] (8192,) [0 0 0 ... 1 1 0] (8192,) 0.3411826173375903 tensor(0.3412, device='cuda:0', dtype=torch.float64)\n" - ] - }, - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0mbest_val_metric\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m train(\n\u001b[0m\u001b[1;32m 26\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdatasets\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/examples/train.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(algorithm, datasets, general_logger, config, epoch_offset, best_val_metric)\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0;31m# First run training\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 90\u001b[0;31m \u001b[0mrun_epoch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0malgorithm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdatasets\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgeneral_logger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconfig\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 91\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;31m# Then run val\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/examples/train.py\u001b[0m in \u001b[0;36mrun_epoch\u001b[0;34m(algorithm, dataset, general_logger, epoch, config, train)\u001b[0m\n\u001b[1;32m 41\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miterator\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 42\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 43\u001b[0;31m \u001b[0mbatch_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 44\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0mbatch_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0malgorithm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevaluate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/examples/algorithms/single_model_algorithm.py\u001b[0m in \u001b[0;36mupdate\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m 101\u001b[0m \u001b[0;31m# process batch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 102\u001b[0m \u001b[0mresults\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 103\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_update\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 104\u001b[0m \u001b[0;31m# log results\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_log\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresults\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/wilds/examples/algorithms/single_model_algorithm.py\u001b[0m in \u001b[0;36m_update\u001b[0;34m(self, results)\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_grad_norm\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 121\u001b[0m \u001b[0mclip_grad_norm_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmax_grad_norm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 122\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 123\u001b[0m self.step_schedulers(\n\u001b[1;32m 124\u001b[0m \u001b[0mis_epoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/torch/autograd/grad_mode.py\u001b[0m in \u001b[0;36mdecorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__class__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 27\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mcast\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdecorate_context\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/torch/optim/adam.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 106\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 107\u001b[0m \u001b[0mbeta1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbeta2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgroup\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'betas'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 108\u001b[0;31m F.adam(params_with_grad,\n\u001b[0m\u001b[1;32m 109\u001b[0m \u001b[0mgrads\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 110\u001b[0m \u001b[0mexp_avgs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda2/envs/wilds-final-3/lib/python3.8/site-packages/torch/optim/functional.py\u001b[0m in \u001b[0;36madam\u001b[0;34m(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad, beta1, beta2, lr, weight_decay, eps)\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0mdenom\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mmax_exp_avg_sq\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mmath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbias_correction2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 93\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 94\u001b[0;31m \u001b[0mdenom\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mexp_avg_sq\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mmath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbias_correction2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0meps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 95\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 96\u001b[0m \u001b[0mstep_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlr\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mbias_correction1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " - ] - } - ], - "source": [ - "if not config.eval_only:\n", - " ## Load saved results if resuming\n", - " resume_success = False\n", - " if resume:\n", - " save_path = os.path.join(config.log_dir, 'last_model.pth')\n", - " if not os.path.exists(save_path):\n", - " epochs = [\n", - " int(file.split('_')[0])\n", - " for file in os.listdir(config.log_dir) if file.endswith('.pth')]\n", - " if len(epochs) > 0:\n", - " latest_epoch = max(epochs)\n", - " save_path = os.path.join(config.log_dir, f'{latest_epoch}_model.pth')\n", - " try:\n", - " prev_epoch, best_val_metric = load(algorithm, save_path)\n", - " epoch_offset = prev_epoch + 1\n", - " logger.write(f'Resuming from epoch {epoch_offset} with best val metric {best_val_metric}')\n", - " resume_success = True\n", - " except FileNotFoundError:\n", - " pass\n", - "\n", - " if resume_success == False:\n", - " epoch_offset=0\n", - " best_val_metric=None\n", - " \n", - " train(\n", - " algorithm=algorithm,\n", - " datasets=datasets,\n", - " general_logger=logger,\n", - " config=config,\n", - " epoch_offset=epoch_offset,\n", - " best_val_metric=best_val_metric)\n", - "else:\n", - " if config.eval_epoch is None:\n", - " eval_model_path = os.path.join(config.log_dir, 'best_model.pth')\n", - " else:\n", - " eval_model_path = os.path.join(config.log_dir, f'{config.eval_epoch}_model.pth')\n", - " best_epoch, best_val_metric = load(algorithm, eval_model_path)\n", - " if config.eval_epoch is None:\n", - " epoch = best_epoch\n", - " else:\n", - " epoch = config.eval_epoch\n", - " evaluate(\n", - " algorithm=algorithm,\n", - " datasets=datasets,\n", - " epoch=epoch,\n", - " general_logger=logger,\n", - " config=config)\n", - "\n", - "logger.close()\n", - "for split in datasets:\n", - " datasets[split]['eval_logger'].close()\n", - " datasets[split]['algo_logger'].close()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.5" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} From f4b5417550dd10d6f9aa08242f3451982ba9de3a Mon Sep 17 00:00:00 2001 From: aikanor Date: Sun, 21 Mar 2021 07:41:40 -0700 Subject: [PATCH 081/244] merge with dev branch --- sbox_run_expt.ipynb | 12 ++++++------ wilds/datasets/encodetfbs_dataset.py | 17 +---------------- 2 files changed, 7 insertions(+), 22 deletions(-) diff --git a/sbox_run_expt.ipynb b/sbox_run_expt.ipynb index 6d1a135a..39f5e862 100644 --- a/sbox_run_expt.ipynb +++ b/sbox_run_expt.ipynb @@ -563,23 +563,23 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 2", + "display_name": "Python 3", "language": "python", - "name": "python2" + "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", - "version": 2 + "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.13" + "pygments_lexer": "ipython3", + "version": "3.8.5" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 } diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 588b9fce..ffd537b3 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -138,22 +138,7 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): self._metric = MultiTaskAccuracy() super().__init__(root_dir, download, split_scheme) - - """ - def get_random_label_vec(metadata_df, output_size=128): - # Sample a positively labeled region at random - pos_mdf = metadata_df[metadata_df['y'] == 1] #.iloc[ metadata_df['chr'] == s['chr'], : ] - pos_seed_region = pos_mdf.iloc[np.random.randint(pos_mdf.shape[0])] - - # Extract regions from this chromosome in this celltype, to get a window of labels from - chr_msk = np.array(metadata_df['chr']) == pos_seed_region['chr'] - ct_msk = np.array(metadata_df['celltype']) == pos_seed_region['celltype'] - mdf = metadata_df[chr_msk & ct_msk] - - # Get labels - start_ndx = np.where(mdf['start'] == pos_seed_region['start'])[0][0] - y_label_vec = mdf.iloc[start_ndx:start_ndx+output_size, :]['y'] - """ + def get_input(self, idx, window_size=12800): """ From 3d4ce3aa182d5e1ec2add17fdaa6774b41c06955 Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Sun, 21 Mar 2021 11:29:24 -0700 Subject: [PATCH 082/244] Address PR comments --- examples/configs/datasets.py | 15 ---- examples/evaluate.py | 109 +++++++++++++++------------ wilds/datasets/fmow_dataset.py | 2 +- wilds/datasets/ogbmolpcba_dataset.py | 2 +- 4 files changed, 62 insertions(+), 66 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index fedbbde9..cd2d1d6f 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -284,21 +284,6 @@ }, } -####################################### -### List of the main WILDS datasets ### -####################################### - -main_datasets = [ - 'amazon', - 'camelyon17', - 'civilcomments', - 'fmow', - 'iwildcam', - 'ogb-molpcba', - 'py150', - 'poverty', -] - ########################################## ### Split-specific defaults for Amazon ### ########################################## diff --git a/examples/evaluate.py b/examples/evaluate.py index 6cefda37..c9224260 100644 --- a/examples/evaluate.py +++ b/examples/evaluate.py @@ -4,13 +4,13 @@ import sys import urllib.request from ast import literal_eval -from typing import Any, Dict, List, Union +from typing import Any, Dict, List from urllib.parse import urlparse import numpy as np import torch -from configs.datasets import main_datasets +from wilds import benchmark_datasets from wilds import get_dataset from wilds.datasets.wilds_dataset import WILDSDataset, WILDSSubset @@ -26,58 +26,68 @@ """ -def evaluate_all(path: str, output_path: str, dataset_path: str): +def evaluate_all_benchmarks(predictions_dir: str, output_dir: str, root_dir: str): """ - Evaluate for all the WILDS datasets. + Evaluate predictions for all the WILDS benchmarks. Parameters: - path (str): Path to the directory with predictions. Can be a URL - output_path (str): Output directory - dataset_path (str): Path to the dataset directory + predictions_dir (str): Path to the directory with predictions. Can be a URL + output_dir (str): Output directory + root_dir (str): The directory where datasets can be found """ all_results: Dict[str, Dict[str, Dict[str, float]]] = dict() - for dataset in main_datasets: - all_results[dataset] = evaluate_multiple_replicates( - dataset, path, output_path, dataset_path - ) + for dataset in benchmark_datasets: + try: + all_results[dataset] = evaluate_benchmark( + dataset, predictions_dir, output_dir, root_dir + ) + except Exception as e: + print(f"Could not evaluate predictions for {dataset}:\n{str(e)}") # Write out aggregated results to output file - print(f"Writing complete results to {output_path}...") - with open(os.path.join(output_path, "all_results.json"), "w") as f: + print(f"Writing complete results to {output_dir}...") + with open(os.path.join(output_dir, "all_results.json"), "w") as f: json.dump(all_results, f, indent=4) -def evaluate_multiple_replicates( - dataset_name: str, path: str, output_path: str, dataset_path: str +def evaluate_benchmark( + dataset_name: str, predictions_dir: str, output_dir: str, root_dir: str ) -> Dict[str, Dict[str, float]]: """ - Evaluate across multiple replicates. + Evaluate across multiple replicates for a single benchmark. Parameters: dataset_name (str): Name of the dataset. See datasets.py for the complete list of datasets. - path (str): Path to the directory with predictions. Can be a URL. - output_path (str): Output directory - dataset_path (str): Path to the dataset directory + predictions_dir (str): Path to the directory with predictions. Can be a URL. + output_dir (str): Output directory + root_dir (str): The directory where datasets can be found Returns: Metrics as a dictionary with metrics as the keys and metric values as the values """ - def get_replicates(dataset_name: str) -> List[Union[str, int]]: - if dataset_name == "camelyon17": - return list(range(0, 10)) - elif dataset_name == "poverty": - return ["A", "B", "C", "D", "E"] + def get_replicates(dataset_name: str) -> List[str]: + if dataset_name == "poverty": + return [f"fold:{fold}" for fold in ["A", "B", "C", "D", "E"]] else: - return list(range(0, 3)) + if dataset_name == "camelyon17": + seeds = range(0, 10) + elif dataset_name == "civilcomments": + seeds = range(0, 5) + else: + seeds = range(0, 3) + return [f"seed:{seed}" for seed in seeds] def get_best_prediction_filename( - dataset_name: str, split: str, replicate: Union[str, int] + predictions_dir: str, dataset_name: str, split: str, replicate: str ) -> str: - if dataset_name == "poverty": - return f"{dataset_name}_split:{split}_fold:{replicate}_epoch:best_pred.csv" - else: - return f"{dataset_name}_split:{split}_seed:{replicate}_epoch:best_pred.csv" + run_id = f"{dataset_name}_split:{split}_{replicate}" + for file in os.listdir(predictions_dir): + if file.startswith(run_id) and file.endswith(".csv"): + return file + raise FileNotFoundError( + f"Could not find CSV prediction file that starts with {run_id}." + ) def get_metrics(dataset_name: str) -> List[str]: if "amazon" == dataset_name: @@ -101,14 +111,14 @@ def get_metrics(dataset_name: str) -> List[str]: # Dataset will only be downloaded if it does not exist wilds_dataset: WILDSDataset = get_dataset( - dataset=dataset_name, root_dir=dataset_path, download=True + dataset=dataset_name, root_dir=root_dir, download=True ) - splits: List[str] = wilds_dataset.split_dict.keys() + splits: List[str] = list(wilds_dataset.split_dict.keys()) if "train" in splits: splits.remove("train") replicates_results: Dict[str, Dict[str, List[float]]] = dict() - replicates: List[Union[str, int]] = get_replicates(dataset_name) + replicates: List[str] = get_replicates(dataset_name) metrics: List[str] = get_metrics(dataset_name) # Store the results for each replicate @@ -119,17 +129,17 @@ def get_metrics(dataset_name: str) -> List[str]: for replicate in replicates: predictions_file = get_best_prediction_filename( - dataset_name, split, replicate + predictions_dir, dataset_name, split, replicate ) print( f"Processing split={split}, replicate={replicate}, predictions_file={predictions_file}..." ) - full_path = os.path.join(path, predictions_file) + full_path = os.path.join(predictions_dir, predictions_file) predicted_labels: List[Any] = get_predictions(full_path) predicted_labels_tensor: torch.Tensor = torch.from_numpy( np.array(predicted_labels) ) - metric_results: Dict[str, float] = evaluate( + metric_results: Dict[str, float] = evaluate_replicate( wilds_dataset, split, predicted_labels_tensor ) for metric in metrics: @@ -148,14 +158,14 @@ def get_metrics(dataset_name: str) -> List[str]: aggregated_results[split][metric] = np.mean(replicates_metric_values) # Write out aggregated results to output file - print(f"Writing aggregated results for {dataset_name} to {output_path}...") - with open(os.path.join(output_path, f"{dataset_name}_results.json"), "w") as f: + print(f"Writing aggregated results for {dataset_name} to {output_dir}...") + with open(os.path.join(output_dir, f"{dataset_name}_results.json"), "w") as f: json.dump(aggregated_results, f, indent=4) return aggregated_results -def evaluate( +def evaluate_replicate( dataset: WILDSDataset, split: str, predicted_labels: torch.Tensor ) -> Dict[str, float]: """ @@ -173,8 +183,9 @@ def evaluate( subset: WILDSSubset = dataset.get_subset(split) true_labels: torch.Tensor = subset.y_array metadata: torch.Tensor = subset.metadata_array - # Attempt to resize predicted_labels tensor to match true_labels tensor's shape - predicted_labels.resize_(true_labels.shape) + # predicted_labels.resize_(true_labels.shape) + if predicted_labels.shape != true_labels.shape: + predicted_labels.unsqueeze_(-1) return dataset.eval(predicted_labels, true_labels, metadata)[0] @@ -212,12 +223,12 @@ def is_path_url(path: str) -> bool: def main(): if args.dataset: - evaluate_multiple_replicates( - args.dataset, args.path, args.output_path, args.dataset_path + evaluate_benchmark( + args.dataset, args.predictions_dir, args.output_dir, args.root_dir ) else: print("A dataset was not specified. Evaluating for all WILDS datasets...") - evaluate_all(args.path, args.output_path, args.dataset_path) + evaluate_all_benchmarks(args.predictions_dir, args.output_dir, args.root_dir) print("\nDone.") @@ -226,26 +237,26 @@ def main(): description="Evaluate predictions for WILDS datasets." ) parser.add_argument( - "path", + "predictions_dir", type=str, help="Path to prediction CSV files.", ) parser.add_argument( - "output_path", + "output_dir", type=str, help="Path to output directory.", ) parser.add_argument( "--dataset", type=str, - choices=main_datasets, + choices=benchmark_datasets, help="WILDS dataset to evaluate for.", ) parser.add_argument( - "--dataset-path", + "--root-dir", type=str, default="data", - help="Path to dataset. Defaults to `data` if not specified.", + help="The directory where the datasets can be found (or should be downloaded to, if they do not exist).", ) # Parse args and run this script diff --git a/wilds/datasets/fmow_dataset.py b/wilds/datasets/fmow_dataset.py index 4a310b40..f8e85b9a 100644 --- a/wilds/datasets/fmow_dataset.py +++ b/wilds/datasets/fmow_dataset.py @@ -63,7 +63,7 @@ class FMoWDataset(WILDSDataset): 'compressed_size': 53_893_324_800} } - def __init__(self, version=None, root_dir='data', download=False, split_scheme='official', oracle_training_set=False, seed=111, use_ood_val=False): + def __init__(self, version=None, root_dir='data', download=False, split_scheme='official', oracle_training_set=False, seed=111, use_ood_val=True): self._version = version self._data_dir = self.initialize_data_dir(root_dir, download) diff --git a/wilds/datasets/ogbmolpcba_dataset.py b/wilds/datasets/ogbmolpcba_dataset.py index 413fd330..891e2915 100644 --- a/wilds/datasets/ogbmolpcba_dataset.py +++ b/wilds/datasets/ogbmolpcba_dataset.py @@ -51,7 +51,7 @@ class OGBPCBADataset(WILDSDataset): https://github.com/snap-stanford/ogb/blob/master/LICENSE """ - _dataset_name = 'ogbg-molpcba' + _dataset_name = 'ogb-molpcba' _versions_dict = { '1.0': { 'download_url': None, From e6f794362d0bc280a7ac6738ce2ba5a152ab6396 Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Sun, 21 Mar 2021 11:34:20 -0700 Subject: [PATCH 083/244] Update function name --- examples/evaluate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/evaluate.py b/examples/evaluate.py index c9224260..6ff9eda1 100644 --- a/examples/evaluate.py +++ b/examples/evaluate.py @@ -78,7 +78,7 @@ def get_replicates(dataset_name: str) -> List[str]: seeds = range(0, 3) return [f"seed:{seed}" for seed in seeds] - def get_best_prediction_filename( + def get_prediction_filename( predictions_dir: str, dataset_name: str, split: str, replicate: str ) -> str: run_id = f"{dataset_name}_split:{split}_{replicate}" @@ -128,7 +128,7 @@ def get_metrics(dataset_name: str) -> List[str]: replicates_results[split][metric] = [] for replicate in replicates: - predictions_file = get_best_prediction_filename( + predictions_file = get_prediction_filename( predictions_dir, dataset_name, split, replicate ) print( From 76d179b41aa52094bd791ad487bad70a13573583 Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Sun, 21 Mar 2021 11:35:20 -0700 Subject: [PATCH 084/244] Update function name --- examples/evaluate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/evaluate.py b/examples/evaluate.py index 6ff9eda1..7d957f41 100644 --- a/examples/evaluate.py +++ b/examples/evaluate.py @@ -78,7 +78,7 @@ def get_replicates(dataset_name: str) -> List[str]: seeds = range(0, 3) return [f"seed:{seed}" for seed in seeds] - def get_prediction_filename( + def get_prediction_file( predictions_dir: str, dataset_name: str, split: str, replicate: str ) -> str: run_id = f"{dataset_name}_split:{split}_{replicate}" @@ -128,7 +128,7 @@ def get_metrics(dataset_name: str) -> List[str]: replicates_results[split][metric] = [] for replicate in replicates: - predictions_file = get_prediction_filename( + predictions_file = get_prediction_file( predictions_dir, dataset_name, split, replicate ) print( From 19a0f81e9e15509fd6de75c97a6e641079cbad94 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Sun, 21 Mar 2021 13:04:18 -0700 Subject: [PATCH 085/244] add encode-tfbs to init --- wilds/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/wilds/__init__.py b/wilds/__init__.py index 77f0ad5a..0a22c780 100644 --- a/wilds/__init__.py +++ b/wilds/__init__.py @@ -10,6 +10,7 @@ 'poverty', 'fmow', 'py150', + 'encode-tfbs' ] additional_datasets = [ From c0f8fa6f91a42fa5c73301c04d3b2cb0e1de5a53 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Sun, 21 Mar 2021 14:30:28 -0700 Subject: [PATCH 086/244] fix merge issues --- examples/configs/datasets.py | 31 +++---------- examples/configs/model.py | 6 +-- examples/models/CNN_genome.py | 52 +++++++++++----------- examples/models/initializer.py | 2 +- examples/run_expt.py | 2 +- wilds/datasets/encodetfbs_dataset.py | 66 +++++++++++++++------------- wilds/get_dataset.py | 6 ++- 7 files changed, 77 insertions(+), 88 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index e3f6b2e4..056b508e 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -107,38 +107,17 @@ }, 'encode-tfbs': { 'split_scheme': 'official', - 'model': 'leopard', - 'model_kwargs': {'pretrained': False}, + 'model': 'unet-seq', + 'model_kwargs': {'n_channels_in': 5}, 'train_transform': None, 'eval_transform': None, 'loss_function': 'multitask_bce', 'groupby_fields': ['celltype'], - 'val_metric': 'acc_avg', + 'val_metric': 'avgprec-macro_all', 'val_metric_decreasing': False, - 'optimizer': 'Adam', - 'scheduler': None, - 'batch_size': 64, - 'lr': 0.001, - 'weight_decay': 0.01, - 'n_epochs': 5, - 'n_groups_per_batch': 2, - 'algo_log_metric': 'multitask_avgprec', - # 'irm_lambda': 1.0, - # 'coral_penalty_weight': 0.1, - }, - 'encode-tfbs': { - 'split_scheme': 'official', - 'model': 'leopard', - 'model_kwargs': {'pretrained': False}, - 'train_transform': None, - 'eval_transform': None, - 'loss_function': 'multitask_bce', - 'groupby_fields': ['celltype'], - 'val_metric': 'acc_avg', - 'val_metric_decreasing': False, - 'optimizer': 'Adam', + 'optimizer': 'Adam', 'scheduler': None, - 'batch_size': 64, + 'batch_size': 128, #64, 'lr': 0.001, 'weight_decay': 0.01, 'n_epochs': 5, diff --git a/examples/configs/model.py b/examples/configs/model.py index a4df713b..3356255c 100644 --- a/examples/configs/model.py +++ b/examples/configs/model.py @@ -36,8 +36,8 @@ 'resnet18_ms': { 'target_resolution': (224, 224), }, - 'logistic_regression': {}, - 'leopard': { + 'logistic_regression': {}, + 'unet-seq': { 'optimizer': 'Adam' - }, + }, } diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py index 7397eeb2..c767f61d 100644 --- a/examples/models/CNN_genome.py +++ b/examples/models/CNN_genome.py @@ -8,27 +8,27 @@ def single_conv(in_channels, out_channels): return nn.Sequential( - nn.Conv1d(in_channels, out_channels, 7, padding=3), - nn.BatchNorm1d(out_channels), + nn.Conv1d(in_channels, out_channels, 7, padding=3), + nn.BatchNorm1d(out_channels), nn.ReLU(inplace=True) ) def double_conv(in_channels, out_channels): return nn.Sequential( - nn.Conv1d(in_channels, out_channels, 7, padding=3), - nn.BatchNorm1d(out_channels), + nn.Conv1d(in_channels, out_channels, 7, padding=3), + nn.BatchNorm1d(out_channels), nn.ReLU(inplace=True), - nn.Conv1d(out_channels, out_channels, 7, padding=3), - nn.BatchNorm1d(out_channels), + nn.Conv1d(out_channels, out_channels, 7, padding=3), + nn.BatchNorm1d(out_channels), nn.ReLU(inplace=True) ) class UNet(nn.Module): - - def __init__(self, out_features=16, n_channels_in=6): + # TODO: This is currently hard-coded to not use out_features + def __init__(self, out_features=16, n_channels_in=5): super().__init__() - + self.dconv_down1 = double_conv(n_channels_in, 15) self.dconv_down2 = double_conv(15, 22) self.dconv_down3 = double_conv(22, 33) @@ -41,7 +41,7 @@ def __init__(self, out_features=16, n_channels_in=6): self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) self.conv_middle = single_conv(109, 109) self.upsamp_6 = nn.ConvTranspose1d(109, 109, 2, stride=2) - + self.dconv_up5 = double_conv(73 + 109, 73) self.upsamp_5 = nn.ConvTranspose1d(73, 73, 2, stride=2) self.dconv_up4 = double_conv(49 + 73, 49) @@ -52,10 +52,10 @@ def __init__(self, out_features=16, n_channels_in=6): self.upsamp_2 = nn.ConvTranspose1d(22, 22, 2, stride=2) self.dconv_up1 = double_conv(15 + 22, 15) self.upsamp_1 = nn.ConvTranspose1d(15, 15, 2, stride=2) - + self.conv_last = nn.Conv1d(15, 1, 200, stride=50, padding=0) - - + + def forward(self, x): # input_size = 12800 # input_channels = 6 @@ -64,32 +64,32 @@ def forward(self, x): conv2 = self.dconv_down2(x) # (input_size / 2) x 22 x = self.maxpool(conv2) # (input_size / 4) x 22 - + conv3 = self.dconv_down3(x) # (input_size / 4) x 33 x = self.maxpool(conv3) # (input_size / 8) x 33 - + conv4 = self.dconv_down4(x) # (input_size / 8) x 49 x = self.maxpool(conv4) # (input_size / 16) x 49 - + conv5 = self.dconv_down5(x) # (input_size / 16) x 73 x = self.maxpool(conv5) # (input_size / 32) x 73 - + conv6 = self.dconv_down6(x) # (input_size / 32) x 109 - # conv6 = self.conv_middle(conv6) # Optional: convolution here. - + # conv6 = self.conv_middle(conv6) # Optional: convolution here. + # Encoder finished. - + x = self.upsamp_6(conv6) # (input_size / 16) x 109 x = torch.cat([x, conv5], dim=1) # (input_size / 16) x (109 + 73) - + x = self.dconv_up5(x) # (input_size / 16) x 73 x = self.upsamp_5(x) # (input_size / 8) x 73 x = torch.cat([x, conv4], dim=1) # (input_size / 8) x (73 + 49) - + x = self.dconv_up4(x) # (input_size / 8) x 49 x = self.upsamp_4(x) # (input_size / 4) x 49 x = torch.cat([x, conv3], dim=1) # (input_size / 4) x (49 + 33) - + x = self.dconv_up3(x) # (input_size / 4) x 33 x = self.upsamp_3(x) # (input_size / 2) x 33 x = torch.cat([x, conv2], dim=1) # (input_size / 2) x (33 + 22) @@ -97,10 +97,10 @@ def forward(self, x): x = self.dconv_up2(x) # (input_size / 2) x 22 x = self.upsamp_2(x) # (input_size) x 22 x = torch.cat([x, conv1], dim=1) # (input_size) x (22 + 15) - + x = self.dconv_up1(x) # (input_size) x 15 - + # middle 128 bits out = self.conv_last(x)[:, :, 64:192] - + return torch.squeeze(out) diff --git a/examples/models/initializer.py b/examples/models/initializer.py index 38615e3b..4f6fef7e 100644 --- a/examples/models/initializer.py +++ b/examples/models/initializer.py @@ -77,7 +77,7 @@ def initialize_model(config, d_out, is_featurizer=False): if is_featurizer: raise NotImplementedError("Featurizer not supported for UNet") else: - model = UNet(out_features=d_out) + model = UNet(out_features=d_out, **config.model_kwargs) else: raise ValueError(f'Model: {config.model} not recognized.') return model diff --git a/examples/run_expt.py b/examples/run_expt.py index 173603ab..4dff02ef 100644 --- a/examples/run_expt.py +++ b/examples/run_expt.py @@ -155,7 +155,7 @@ def main(): transform_name=config.eval_transform, config=config, dataset=full_dataset) - + train_grouper = CombinatorialGrouper( dataset=full_dataset, groupby_fields=config.groupby_fields) diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 76010d09..4a85c8a9 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -11,9 +11,9 @@ class EncodeTFBSDataset(WILDSDataset): """ - ENCODE-DREAM-wilds dataset of transcription factor binding sites. - This is a subset of the dataset from the ENCODE-DREAM in vivo Transcription Factor Binding Site Prediction Challenge. - + ENCODE-DREAM-wilds dataset of transcription factor binding sites. + This is a subset of the dataset from the ENCODE-DREAM in vivo Transcription Factor Binding Site Prediction Challenge. + Input (x): 12800-base-pair regions of sequence with a quantified chromatin accessibility readout. @@ -22,20 +22,24 @@ class EncodeTFBSDataset(WILDSDataset): Metadata: Each sequence is annotated with the celltype of origin (a string) and the chromosome of origin (a string). - + Website: https://www.synapse.org/#!Synapse:syn6131484 """ - def __init__(self, root_dir='data', download=False, split_scheme='official'): + _dataset_name = 'encode-tfbs' + _versions_dict = { + '1.0': { + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/', + 'compressed_size': None}} + + def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): itime = time.time() - self._dataset_name = 'encode-tfbs' - self._version = '1.0' - self._download_url = 'https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/' + self._version = version self._data_dir = self.initialize_data_dir(root_dir, download) self._y_size = 128 # self._n_classes = 2 - + self._train_chroms = ['chr3']#, 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] self._val_chroms = ['chr2']#, 'chr9', 'chr11'] self._test_chroms = ['chr1']#, 'chr8', 'chr21'] @@ -45,15 +49,15 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): self._test_celltype = ['GM12878'] self._all_chroms = self._train_chroms + self._val_chroms + self._test_chroms self._all_celltypes = self._train_celltypes + self._val_celltype + self._test_celltype - + self._metadata_map = {} self._metadata_map['chr'] = self._all_chroms self._metadata_map['celltype'] = self._all_celltypes - + # Get the splits if split_scheme=='official': split_scheme = 'standard' - + self._split_scheme = split_scheme self._split_dict = { 'train': 0, @@ -67,7 +71,7 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): 'test': 'Test', 'val': 'Validation (OOD)', } - + # Load sequence and DNase features sequence_filename = os.path.join(self._data_dir, 'sequence.npz') seq_arr = np.load(sequence_filename) @@ -75,7 +79,9 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): for chrom in self._all_chroms: #seq_arr: self._seq_bp[chrom] = seq_arr[chrom] print(chrom, time.time() - itime) - + + # Delete seq_arr? + self._dnase_allcelltypes = {} # ct = 'avg' # dnase_avg_bw_path = os.path.join(self._data_dir, 'DNase/{}.bigwig'.format(ct)) @@ -90,61 +96,61 @@ def __init__(self, root_dir='data', download=False, split_scheme='official'): """ dnase_bw_path = os.path.join(self._data_dir, 'DNase/{}.bigwig'.format(ct)) self._dnase_allcelltypes[ct] = pyBigWig.open(dnase_bw_path) - + self._metadata_df = pd.read_csv( - self._data_dir + '/labels/{}/metadata_df.bed'.format(self._transcription_factor), - sep='\t', header=None, + self._data_dir + '/labels/{}/metadata_df.bed'.format(self._transcription_factor), + sep='\t', header=None, index_col=None, names=['chr', 'start', 'stop', 'celltype'] ) - + train_regions_mask = np.isin(self._metadata_df['chr'], self._train_chroms) val_regions_mask = np.isin(self._metadata_df['chr'], self._val_chroms) test_regions_mask = np.isin(self._metadata_df['chr'], self._test_chroms) train_celltype_mask = np.isin(self._metadata_df['celltype'], self._train_celltypes) val_celltype_mask = np.isin(self._metadata_df['celltype'], self._val_celltype) test_celltype_mask = np.isin(self._metadata_df['celltype'], self._test_celltype) - + split_array = -1*np.ones(self._metadata_df.shape[0]).astype(int) split_array[np.logical_and(train_regions_mask, train_celltype_mask)] = self._split_dict['train'] split_array[np.logical_and(test_regions_mask, test_celltype_mask)] = self._split_dict['test'] # Validate using validation chr, either using a designated validation cell line ('val') or a training cell line ('id_val') split_array[np.logical_and(val_regions_mask, val_celltype_mask)] = self._split_dict['val'] split_array[np.logical_and(val_regions_mask, train_celltype_mask)] = self._split_dict['id_val'] - + if self._split_scheme=='standard': self._metadata_df.insert(len(self._metadata_df.columns), 'split', split_array) else: raise ValueError(f'Split scheme {self._split_scheme} not recognized') - + metadata_mask = (self._metadata_df['split'] != -1) self._metadata_df = self._metadata_df[self._metadata_df['split'] != -1] - + chr_ints = self._metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['chr'])] )).values celltype_ints = self._metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['celltype'])] )).values self._split_array = self._metadata_df['split'].values self._y_array = torch.Tensor(np.load( self._data_dir + '/labels/{}/metadata_y.npy'.format(self._transcription_factor))) self._y_array = self._y_array[metadata_mask] - + self._metadata_array = torch.stack( - (torch.LongTensor(chr_ints), + (torch.LongTensor(chr_ints), torch.LongTensor(celltype_ints) ), dim=1) self._metadata_fields = ['chr', 'celltype'] - + self._eval_grouper = CombinatorialGrouper( dataset=self, groupby_fields=['celltype']) - + self._metric = MTAveragePrecision() - + super().__init__(root_dir, download, split_scheme) - + def get_input(self, idx, window_size=12800): """ Returns x for a given idx in metadata_array, which has been filtered to only take windows with the desired stride. - Computes this from: + Computes this from: (1) sequence features in self._seq_bp (2) DNase bigwig file handles in self._dnase_allcelltypes (3) Metadata for the index (location along the genome with 6400bp window width) @@ -160,7 +166,7 @@ def get_input(self, idx, window_size=12800): # print("{}:{}-{}".format(chrom, interval_start, interval_end)) # dnase_avg = self._dnase_allcelltypes['avg'].values(chrom, interval_start, interval_end, numpy=True) return torch.tensor(np.column_stack( - [np.nan_to_num(seq_this), + [np.nan_to_num(seq_this), np.nan_to_num(dnase_this)]#, np.nan_to_num(dnase_avg)] ).T) diff --git a/wilds/get_dataset.py b/wilds/get_dataset.py index 1073100f..a67ba538 100644 --- a/wilds/get_dataset.py +++ b/wilds/get_dataset.py @@ -55,7 +55,7 @@ def get_dataset(dataset, version=None, **dataset_kwargs): elif dataset == 'poverty': if version == '1.0': from wilds.datasets.archive.poverty_v1_0_dataset import PovertyMapDataset - else: + else: from wilds.datasets.poverty_dataset import PovertyMapDataset return PovertyMapDataset(version=version, **dataset_kwargs) @@ -77,3 +77,7 @@ def get_dataset(dataset, version=None, **dataset_kwargs): elif dataset == 'sqf': from wilds.datasets.sqf_dataset import SQFDataset return SQFDataset(version=version, **dataset_kwargs) + + elif dataset == 'encode-tfbs': + from wilds.datasets.encodetfbs_dataset import EncodeTFBSDataset + return EncodeTFBSDataset(version=version, **dataset_kwargs) From f033e6e4c1c7e72068ade6683d3b4ad5fdf9e6f9 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Sun, 21 Mar 2021 18:17:55 -0700 Subject: [PATCH 087/244] fix MultiTaskAveragePrecision --- examples/configs/datasets.py | 6 ++--- examples/configs/supported.py | 4 ++-- wilds/common/metrics/all_metrics.py | 35 +++++++++++++++++++++------- wilds/common/metrics/metric.py | 6 ++++- wilds/datasets/encodetfbs_dataset.py | 5 ++-- 5 files changed, 39 insertions(+), 17 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 056b508e..7bc99584 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -113,12 +113,12 @@ 'eval_transform': None, 'loss_function': 'multitask_bce', 'groupby_fields': ['celltype'], - 'val_metric': 'avgprec-macro_all', + 'val_metric': 'avgprec-macro_all', 'val_metric_decreasing': False, 'optimizer': 'Adam', 'scheduler': None, - 'batch_size': 128, #64, - 'lr': 0.001, + 'batch_size': 64, #64, + 'lr': 0.0001, 'weight_decay': 0.01, 'n_epochs': 5, 'n_groups_per_batch': 2, diff --git a/examples/configs/supported.py b/examples/configs/supported.py index 10b3e361..4f2e43c0 100644 --- a/examples/configs/supported.py +++ b/examples/configs/supported.py @@ -17,8 +17,8 @@ 'accuracy': Accuracy(prediction_fn=multiclass_logits_to_pred), 'mse': MSE(), 'multitask_accuracy': MultiTaskAccuracy(prediction_fn=multiclass_logits_to_pred), - 'multitask_binary_accuracy': MultiTaskAccuracy(prediction_fn=binary_logits_to_pred), - 'multitask_avgprec': MTAveragePrecision(prediction_fn=binary_logits_to_pred), + 'multitask_binary_accuracy': MultiTaskAccuracy(prediction_fn=binary_logits_to_pred), + 'multitask_avgprec': MultiTaskAveragePrecision(prediction_fn=None), None: None, } diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py index 55111775..5a6a86f4 100644 --- a/wilds/common/metrics/all_metrics.py +++ b/wilds/common/metrics/all_metrics.py @@ -4,7 +4,7 @@ import torch.nn.functional as F from wilds.common.metrics.metric import Metric, ElementwiseMetric, MultiTaskMetric from wilds.common.metrics.loss import ElementwiseLoss -from wilds.common.utils import avg_over_groups, minimum, maximum +from wilds.common.utils import avg_over_groups, minimum, maximum, get_counts import sklearn.metrics from scipy.stats import pearsonr @@ -19,7 +19,7 @@ def binary_logits_to_score(logits): def multiclass_logits_to_pred(logits): """ - Takes multi-class logits of size (batch_size, ..., n_classes) and returns predictions + Takes multi-class logits of size (batch_size, ..., n_classes) and returns predictions by taking an argmax at the last dimension """ assert logits.dim() > 1 @@ -74,13 +74,30 @@ def _compute_flattened(self, flattened_y_pred, flattened_y_true): ytr = np.array(flattened_y_true.squeeze().detach().cpu().numpy() > 0) ypr = flattened_y_pred.squeeze().detach().cpu().numpy() score = sklearn.metrics.average_precision_score( - ytr, - ypr, + ytr, + ypr, average=self.average ) to_ret = torch.tensor(score).to(flattened_y_pred.device) return to_ret - + + def _compute_group_wise(self, y_pred, y_true, g, n_groups): + group_metrics = [] + group_counts = get_counts(g, n_groups) + for group_idx in range(n_groups): + if group_counts[group_idx]==0: + group_metrics.append(torch.tensor(0., device=g.device)) + else: + flattened_metrics, _ = self.compute_flattened( + y_pred[g == group_idx], + y_true[g == group_idx], + return_dict=False) + group_metrics.append(flattened_metrics) + group_metrics = torch.stack(group_metrics) + worst_group_metric = self.worst(group_metrics[group_counts>0]) + + return group_metrics, group_counts, worst_group_metric + def _compute(self, y_pred, y_true): return self._compute_flattened(y_pred, y_true) @@ -120,8 +137,8 @@ def _compute(self, y_pred, y_true): if self.prediction_fn is not None: y_pred = self.prediction_fn(y_pred) score = sklearn.metrics.average_precision_score( - np.array(y_true.squeeze().detach().cpu().numpy() > 0), - y_pred.squeeze().detach().cpu().numpy(), + np.array(y_true.squeeze().detach().cpu().numpy() > 0), + y_pred.squeeze().detach().cpu().numpy(), average=self.average ) return torch.tensor(score) @@ -145,8 +162,8 @@ def _compute(self, y_pred, y_true): ytr = np.array(torch.flatten(y_true.squeeze()).detach().cpu().numpy() > 0) ypr = torch.flatten(y_pred.squeeze()).detach().cpu().numpy() score = sklearn.metrics.average_precision_score( - ytr, - ypr, + ytr, + ypr, average=self.average ) to_ret = torch.tensor(score).to(y_pred.device) diff --git a/wilds/common/metrics/metric.py b/wilds/common/metrics/metric.py index 3a628886..c1a6f9bf 100644 --- a/wilds/common/metrics/metric.py +++ b/wilds/common/metrics/metric.py @@ -233,7 +233,11 @@ def _compute_group_wise(self, y_pred, y_true, g, n_groups): flattened_metrics, indices = self.compute_flattened(y_pred, y_true, return_dict=False) flattened_g = g[indices] # print(flattened_metrics.shape, flattened_g.shape, (indices > 0).sum(), y_pred.shape, y_true.shape) - group_metrics, group_counts = avg_over_groups(flattened_metrics, flattened_g, n_groups) + try: + group_metrics, group_counts = avg_over_groups(flattened_metrics, flattened_g, n_groups) + except: + import IPython + IPython.embed() worst_group_metric = self.worst(group_metrics[group_counts>0]) return group_metrics, group_counts, worst_group_metric diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 4a85c8a9..78367fae 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -5,7 +5,7 @@ import pyBigWig from wilds.datasets.wilds_dataset import WILDSDataset from wilds.common.grouper import CombinatorialGrouper -from wilds.common.metrics.all_metrics import Accuracy, MultiTaskAccuracy, MTAveragePrecision +from wilds.common.metrics.all_metrics import Accuracy, MultiTaskAccuracy, MTAveragePrecision, MultiTaskAveragePrecision all_chrom_names = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] @@ -143,7 +143,8 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' dataset=self, groupby_fields=['celltype']) - self._metric = MTAveragePrecision() + # self._metric = MTAveragePrecision() + self._metric = MultiTaskAveragePrecision() super().__init__(root_dir, download, split_scheme) From b9c1469348b00fe1032c51b6c0d0f4cb9aa5285e Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Sun, 21 Mar 2021 19:26:41 -0700 Subject: [PATCH 088/244] ambiguous label handling --- examples/configs/datasets.py | 5 +++-- wilds/datasets/encodetfbs_dataset.py | 15 ++++++++++++--- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 7bc99584..0bea84a7 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -113,7 +113,7 @@ 'eval_transform': None, 'loss_function': 'multitask_bce', 'groupby_fields': ['celltype'], - 'val_metric': 'avgprec-macro_all', + 'val_metric': 'avgprec-macro_all', 'val_metric_decreasing': False, 'optimizer': 'Adam', 'scheduler': None, @@ -122,7 +122,8 @@ 'weight_decay': 0.01, 'n_epochs': 5, 'n_groups_per_batch': 2, - 'algo_log_metric': 'multitask_avgprec', + 'algo_log_metric': 'multitask_binary_accuracy', + # 'algo_log_metric': 'multitask_avgprec', # 'irm_lambda': 1.0, # 'coral_penalty_weight': 0.1, }, diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 78367fae..5e30afbc 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -128,10 +128,16 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' chr_ints = self._metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['chr'])] )).values celltype_ints = self._metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['celltype'])] )).values self._split_array = self._metadata_df['split'].values - self._y_array = torch.Tensor(np.load( + self._y_array = torch.tensor(np.load( self._data_dir + '/labels/{}/metadata_y.npy'.format(self._transcription_factor))) self._y_array = self._y_array[metadata_mask] + # ~10% of the dataset has ambiguous labels + # i.e., we can't tell if there is a binding event or not. + # This typically happens at the flanking regions of peaks. + # For our purposes, we will ignore these ambiguous labels during training and eval. + self.y_array[self.y_array == 0.5] = float('nan') + self._metadata_array = torch.stack( (torch.LongTensor(chr_ints), torch.LongTensor(celltype_ints) @@ -166,9 +172,12 @@ def get_input(self, idx, window_size=12800): dnase_this = dnase_bw.values(chrom, interval_start, interval_end, numpy=True) # print("{}:{}-{}".format(chrom, interval_start, interval_end)) # dnase_avg = self._dnase_allcelltypes['avg'].values(chrom, interval_start, interval_end, numpy=True) + + assert(np.isnan(seq_this).sum() == 0) + assert(np.isnan(dnase_this).sum() == 0) return torch.tensor(np.column_stack( - [np.nan_to_num(seq_this), - np.nan_to_num(dnase_this)]#, np.nan_to_num(dnase_avg)] + [seq_this, + dnase_this]#, np.nan_to_num(dnase_avg)] ).T) def eval(self, y_pred, y_true, metadata): From cf11a67e3186bfb6f134174f5a22b1fb91305e9c Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Sun, 21 Mar 2021 19:32:28 -0700 Subject: [PATCH 089/244] fix _compute for multitaskAP --- wilds/common/metrics/all_metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py index 5a6a86f4..bcf9ad7b 100644 --- a/wilds/common/metrics/all_metrics.py +++ b/wilds/common/metrics/all_metrics.py @@ -98,8 +98,8 @@ def _compute_group_wise(self, y_pred, y_true, g, n_groups): return group_metrics, group_counts, worst_group_metric - def _compute(self, y_pred, y_true): - return self._compute_flattened(y_pred, y_true) + # def _compute(self, y_pred, y_true): + # return self._compute_flattened(y_pred, y_true) def worst(self, metrics): return minimum(metrics) From 5ba4d67fc3804d9f109958807c4fa8dd99152c17 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Sun, 21 Mar 2021 20:59:22 -0700 Subject: [PATCH 090/244] revert config change for data loaders and cleanup --- examples/configs/data_loader.py | 4 ++-- examples/configs/supported.py | 4 ++-- wilds/common/metrics/metric.py | 7 +------ 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/examples/configs/data_loader.py b/examples/configs/data_loader.py index c00b1b64..8ddaa267 100644 --- a/examples/configs/data_loader.py +++ b/examples/configs/data_loader.py @@ -1,6 +1,6 @@ loader_defaults = { - 'loader_kwargs':{ - 'num_workers': 1, + 'loader_kwargs': { + 'num_workers': 4, 'pin_memory': True, }, 'n_groups_per_batch': 4, diff --git a/examples/configs/supported.py b/examples/configs/supported.py index 4f2e43c0..6c60ea53 100644 --- a/examples/configs/supported.py +++ b/examples/configs/supported.py @@ -18,7 +18,7 @@ 'mse': MSE(), 'multitask_accuracy': MultiTaskAccuracy(prediction_fn=multiclass_logits_to_pred), 'multitask_binary_accuracy': MultiTaskAccuracy(prediction_fn=binary_logits_to_pred), - 'multitask_avgprec': MultiTaskAveragePrecision(prediction_fn=None), + 'multitask_avgprec': MultiTaskAveragePrecision(prediction_fn=None), None: None, } @@ -32,7 +32,7 @@ transforms = ['bert', 'image_base', 'image_resize_and_center_crop', 'poverty_train'] models = ['resnet18_ms', 'resnet50', 'resnet34', 'wideresnet50', 'densenet121', 'bert-base-uncased', 'distilbert-base-uncased', - 'gin-virtual', 'logistic_regression', 'code-gpt-py', 'leopard'] + 'gin-virtual', 'logistic_regression', 'code-gpt-py', 'unet-seq'] algorithms = ['ERM', 'groupDRO', 'deepCORAL', 'IRM'] optimizers = ['SGD', 'Adam', 'AdamW'] schedulers = ['linear_schedule_with_warmup', 'ReduceLROnPlateau', 'StepLR'] diff --git a/wilds/common/metrics/metric.py b/wilds/common/metrics/metric.py index c1a6f9bf..9c4372b0 100644 --- a/wilds/common/metrics/metric.py +++ b/wilds/common/metrics/metric.py @@ -232,12 +232,7 @@ def _compute(self, y_pred, y_true): def _compute_group_wise(self, y_pred, y_true, g, n_groups): flattened_metrics, indices = self.compute_flattened(y_pred, y_true, return_dict=False) flattened_g = g[indices] - # print(flattened_metrics.shape, flattened_g.shape, (indices > 0).sum(), y_pred.shape, y_true.shape) - try: - group_metrics, group_counts = avg_over_groups(flattened_metrics, flattened_g, n_groups) - except: - import IPython - IPython.embed() + group_metrics, group_counts = avg_over_groups(flattened_metrics, flattened_g, n_groups) worst_group_metric = self.worst(group_metrics[group_counts>0]) return group_metrics, group_counts, worst_group_metric From 38815182e554683cfa8a8987dedb93dec87de1be Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Sun, 21 Mar 2021 21:56:51 -0700 Subject: [PATCH 091/244] grad_mode change for algorithm.train() and num_workers for encode --- examples/algorithms/algorithm.py | 2 -- examples/configs/datasets.py | 4 ++-- examples/train.py | 3 +++ 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/algorithms/algorithm.py b/examples/algorithms/algorithm.py index c93d960a..f19af3d8 100644 --- a/examples/algorithms/algorithm.py +++ b/examples/algorithms/algorithm.py @@ -39,14 +39,12 @@ def evaluate(self, batch): """ raise NotImplementedError - # Taken from domainbed def train(self, mode=True): """ Switch to train mode """ self.is_training = mode super().train(mode) - torch.set_grad_enabled(mode) self.reset_log() @property diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 0bea84a7..4618a236 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -109,6 +109,7 @@ 'split_scheme': 'official', 'model': 'unet-seq', 'model_kwargs': {'n_channels_in': 5}, + 'loader_kwargs': {'num_workers': 1}, # pybigwig seems to have trouble with multiprocessing 'train_transform': None, 'eval_transform': None, 'loss_function': 'multitask_bce', @@ -117,13 +118,12 @@ 'val_metric_decreasing': False, 'optimizer': 'Adam', 'scheduler': None, - 'batch_size': 64, #64, + 'batch_size': 64, 'lr': 0.0001, 'weight_decay': 0.01, 'n_epochs': 5, 'n_groups_per_batch': 2, 'algo_log_metric': 'multitask_binary_accuracy', - # 'algo_log_metric': 'multitask_avgprec', # 'irm_lambda': 1.0, # 'coral_penalty_weight': 0.1, }, diff --git a/examples/train.py b/examples/train.py index 774f3d1e..596b69ba 100644 --- a/examples/train.py +++ b/examples/train.py @@ -11,8 +11,10 @@ def run_epoch(algorithm, dataset, general_logger, epoch, config, train): if train: algorithm.train() + torch.set_grad_enabled(True) else: algorithm.eval() + torch.set_grad_enabled(False) # Not preallocating memory is slower # but makes it easier to handle different types of data loaders @@ -114,6 +116,7 @@ def train(algorithm, datasets, general_logger, config, epoch_offset, best_val_me def evaluate(algorithm, datasets, epoch, general_logger, config): algorithm.eval() + torch.set_grad_enabled(False) for split, dataset in datasets.items(): if (not config.evaluate_all_splits) and (split not in config.eval_splits): continue From bcd082cddf71426dffe259d76cbf209e3627a0c5 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Mon, 22 Mar 2021 10:58:10 -0700 Subject: [PATCH 092/244] in-dist baseline --- wilds/datasets/encodetfbs_dataset.py | 185 ++++++++++++++------------- 1 file changed, 97 insertions(+), 88 deletions(-) diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 5e30afbc..b9d5c89a 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -5,9 +5,7 @@ import pyBigWig from wilds.datasets.wilds_dataset import WILDSDataset from wilds.common.grouper import CombinatorialGrouper -from wilds.common.metrics.all_metrics import Accuracy, MultiTaskAccuracy, MTAveragePrecision, MultiTaskAveragePrecision - -all_chrom_names = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] +from wilds.common.metrics.all_metrics import MultiTaskAveragePrecision class EncodeTFBSDataset(WILDSDataset): """ @@ -38,105 +36,119 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' self._version = version self._data_dir = self.initialize_data_dir(root_dir, download) self._y_size = 128 - # self._n_classes = 2 - - self._train_chroms = ['chr3']#, 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] - self._val_chroms = ['chr2']#, 'chr9', 'chr11'] - self._test_chroms = ['chr1']#, 'chr8', 'chr21'] self._transcription_factor = 'MAX' - self._train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] - self._val_celltype = ['A549'] - self._test_celltype = ['GM12878'] - self._all_chroms = self._train_chroms + self._val_chroms + self._test_chroms - self._all_celltypes = self._train_celltypes + self._val_celltype + self._test_celltype + + # Read in metadata and labels + self._metadata_df = pd.read_csv( + self._data_dir + '/labels/{}/metadata_df.bed'.format(self._transcription_factor), + sep='\t', header=None, + index_col=None, names=['chr', 'start', 'stop', 'celltype'] + ) + self._y_array = torch.tensor(np.load( + self._data_dir + '/labels/{}/metadata_y.npy'.format(self._transcription_factor))) + + # ~10% of the dataset has ambiguous labels + # i.e., we can't tell if there is a binding event or not. + # This typically happens at the flanking regions of peaks. + # For our purposes, we will ignore these ambiguous labels during training and eval. + self.y_array[self.y_array == 0.5] = float('nan') + + # Construct splits + self._split_scheme = split_scheme + if self._split_scheme == 'official': + splits = { + 'train': { + 'chroms': ['chr3'], + 'celltypes': ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] + }, + 'id_val': { + 'chroms': ['chr2'], + 'celltypes': ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] + }, + 'val': { + 'chroms': ['chr2'], + 'celltypes': ['A549'] + }, + 'test': { + 'chroms': ['chr1'], + 'celltypes': ['GM12878'] + }, + } + self._split_dict = { + 'train': 0, + 'val': 1, + 'test': 2, + 'id_val': 3, + } + self._split_names = { + 'train': 'Train', + 'id_val': 'Validation (ID)', + 'test': 'Test', + 'val': 'Validation (OOD)', + } + elif self._split_scheme == 'in-dist': + splits = { + 'train': { + 'chroms': ['chr3'], + 'celltypes': ['GM12878'], + }, + 'val': { + 'chroms': ['chr2'], + 'celltypes': ['GM12878'] + }, + 'test': { + 'chroms': ['chr1'], + 'celltypes': ['GM12878'] + }, + } + self._split_dict = { + 'train': 0, + 'val': 1, + 'test': 2, + } + self._split_names = { + 'train': 'Train', + 'test': 'Test', + 'val': 'Validation (OOD)', + } + else: + raise ValueError(f'Split scheme {self._split_scheme} not recognized') + + self._split_array = -1 * np.ones(self._metadata_df.shape[0]).astype(int) + for split, d in splits.items(): + chrom_mask = np.isin(self._metadata_df['chr'], d['chroms']) + celltype_mask = np.isin(self._metadata_df['celltype'], d['celltypes']) + self._split_array[chrom_mask & celltype_mask] = self._split_dict[split] + + indices_to_keep = (self._split_array != -1) + self._metadata_df = self._metadata_df[indices_to_keep] + self._split_array = self._split_array[indices_to_keep] + self._y_array = self._y_array[indices_to_keep] + + self._all_chroms = sorted(list({chrom for _, d in splits.items() for chrom in d['chroms']})) + self._all_celltypes = sorted(list({chrom for _, d in splits.items() for chrom in d['celltypes']})) self._metadata_map = {} self._metadata_map['chr'] = self._all_chroms self._metadata_map['celltype'] = self._all_celltypes - # Get the splits - if split_scheme=='official': - split_scheme = 'standard' - - self._split_scheme = split_scheme - self._split_dict = { - 'train': 0, - 'id_val': 1, - 'test': 2, - 'val': 3 - } - self._split_names = { - 'train': 'Train', - 'id_val': 'Validation (ID)', - 'test': 'Test', - 'val': 'Validation (OOD)', - } - - # Load sequence and DNase features + # Load sequence into memory sequence_filename = os.path.join(self._data_dir, 'sequence.npz') seq_arr = np.load(sequence_filename) self._seq_bp = {} - for chrom in self._all_chroms: #seq_arr: + for chrom in self._all_chroms: self._seq_bp[chrom] = seq_arr[chrom] print(chrom, time.time() - itime) + del seq_arr - # Delete seq_arr? - + # Set up file handles for DNase features self._dnase_allcelltypes = {} - # ct = 'avg' - # dnase_avg_bw_path = os.path.join(self._data_dir, 'DNase/{}.bigwig'.format(ct)) - # self._dnase_allcelltypes[ct] = pyBigWig.open(dnase_avg_bw_path) for ct in self._all_celltypes: - """ - dnase_filename = os.path.join(self._data_dir, '{}_dnase.npz'.format(ct)) - dnase_npz_contents = np.load(dnase_filename) - self._dnase_allcelltypes[ct] = {} - for chrom in self._all_chroms: #self._seq_bp: - self._dnase_allcelltypes[ct][chrom] = dnase_npz_contents[chrom] - """ dnase_bw_path = os.path.join(self._data_dir, 'DNase/{}.bigwig'.format(ct)) self._dnase_allcelltypes[ct] = pyBigWig.open(dnase_bw_path) - self._metadata_df = pd.read_csv( - self._data_dir + '/labels/{}/metadata_df.bed'.format(self._transcription_factor), - sep='\t', header=None, - index_col=None, names=['chr', 'start', 'stop', 'celltype'] - ) - - train_regions_mask = np.isin(self._metadata_df['chr'], self._train_chroms) - val_regions_mask = np.isin(self._metadata_df['chr'], self._val_chroms) - test_regions_mask = np.isin(self._metadata_df['chr'], self._test_chroms) - train_celltype_mask = np.isin(self._metadata_df['celltype'], self._train_celltypes) - val_celltype_mask = np.isin(self._metadata_df['celltype'], self._val_celltype) - test_celltype_mask = np.isin(self._metadata_df['celltype'], self._test_celltype) - - split_array = -1*np.ones(self._metadata_df.shape[0]).astype(int) - split_array[np.logical_and(train_regions_mask, train_celltype_mask)] = self._split_dict['train'] - split_array[np.logical_and(test_regions_mask, test_celltype_mask)] = self._split_dict['test'] - # Validate using validation chr, either using a designated validation cell line ('val') or a training cell line ('id_val') - split_array[np.logical_and(val_regions_mask, val_celltype_mask)] = self._split_dict['val'] - split_array[np.logical_and(val_regions_mask, train_celltype_mask)] = self._split_dict['id_val'] - - if self._split_scheme=='standard': - self._metadata_df.insert(len(self._metadata_df.columns), 'split', split_array) - else: - raise ValueError(f'Split scheme {self._split_scheme} not recognized') - - metadata_mask = (self._metadata_df['split'] != -1) - self._metadata_df = self._metadata_df[self._metadata_df['split'] != -1] - chr_ints = self._metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['chr'])] )).values celltype_ints = self._metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['celltype'])] )).values - self._split_array = self._metadata_df['split'].values - self._y_array = torch.tensor(np.load( - self._data_dir + '/labels/{}/metadata_y.npy'.format(self._transcription_factor))) - self._y_array = self._y_array[metadata_mask] - - # ~10% of the dataset has ambiguous labels - # i.e., we can't tell if there is a binding event or not. - # This typically happens at the flanking regions of peaks. - # For our purposes, we will ignore these ambiguous labels during training and eval. - self.y_array[self.y_array == 0.5] = float('nan') self._metadata_array = torch.stack( (torch.LongTensor(chr_ints), @@ -149,7 +161,6 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' dataset=self, groupby_fields=['celltype']) - # self._metric = MTAveragePrecision() self._metric = MultiTaskAveragePrecision() super().__init__(root_dir, download, split_scheme) @@ -166,18 +177,16 @@ def get_input(self, idx, window_size=12800): this_metadata = self._metadata_df.iloc[idx, :] chrom = this_metadata['chr'] interval_start = this_metadata['start'] - int(window_size/4) - interval_end = interval_start + window_size #this_metadata['stop'] + interval_end = interval_start + window_size seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end] dnase_bw = self._dnase_allcelltypes[this_metadata['celltype']] dnase_this = dnase_bw.values(chrom, interval_start, interval_end, numpy=True) - # print("{}:{}-{}".format(chrom, interval_start, interval_end)) - # dnase_avg = self._dnase_allcelltypes['avg'].values(chrom, interval_start, interval_end, numpy=True) assert(np.isnan(seq_this).sum() == 0) assert(np.isnan(dnase_this).sum() == 0) return torch.tensor(np.column_stack( [seq_this, - dnase_this]#, np.nan_to_num(dnase_avg)] + dnase_this] ).T) def eval(self, y_pred, y_true, metadata): From dd1013f09d54bfa3a62bae42cec21f79ff36ec49 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Mon, 22 Mar 2021 11:22:38 -0700 Subject: [PATCH 093/244] reorder keys --- wilds/datasets/encodetfbs_dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index b9d5c89a..c3136c33 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -82,9 +82,9 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' } self._split_names = { 'train': 'Train', - 'id_val': 'Validation (ID)', - 'test': 'Test', 'val': 'Validation (OOD)', + 'test': 'Test', + 'id_val': 'Validation (ID)', } elif self._split_scheme == 'in-dist': splits = { @@ -108,8 +108,8 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' } self._split_names = { 'train': 'Train', - 'test': 'Test', 'val': 'Validation (OOD)', + 'test': 'Test', } else: raise ValueError(f'Split scheme {self._split_scheme} not recognized') From a1bde7de19b73df7dfac13b974190dee466dbc93 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Mon, 22 Mar 2021 11:23:23 -0700 Subject: [PATCH 094/244] move metadata code --- wilds/datasets/encodetfbs_dataset.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index c3136c33..395ddc78 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -128,10 +128,6 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' self._all_chroms = sorted(list({chrom for _, d in splits.items() for chrom in d['chroms']})) self._all_celltypes = sorted(list({chrom for _, d in splits.items() for chrom in d['celltypes']})) - self._metadata_map = {} - self._metadata_map['chr'] = self._all_chroms - self._metadata_map['celltype'] = self._all_celltypes - # Load sequence into memory sequence_filename = os.path.join(self._data_dir, 'sequence.npz') seq_arr = np.load(sequence_filename) @@ -156,6 +152,9 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' ), dim=1) self._metadata_fields = ['chr', 'celltype'] + self._metadata_map = {} + self._metadata_map['chr'] = self._all_chroms + self._metadata_map['celltype'] = self._all_celltypes self._eval_grouper = CombinatorialGrouper( dataset=self, From d1644413d3ccd6a65768169ac9f4b9e268b4465c Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Mon, 22 Mar 2021 11:23:42 -0700 Subject: [PATCH 095/244] move metadata code --- wilds/datasets/encodetfbs_dataset.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 395ddc78..4c62f74f 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -143,9 +143,6 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' dnase_bw_path = os.path.join(self._data_dir, 'DNase/{}.bigwig'.format(ct)) self._dnase_allcelltypes[ct] = pyBigWig.open(dnase_bw_path) - chr_ints = self._metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['chr'])] )).values - celltype_ints = self._metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['celltype'])] )).values - self._metadata_array = torch.stack( (torch.LongTensor(chr_ints), torch.LongTensor(celltype_ints) @@ -155,6 +152,8 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' self._metadata_map = {} self._metadata_map['chr'] = self._all_chroms self._metadata_map['celltype'] = self._all_celltypes + chr_ints = self._metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['chr'])] )).values + celltype_ints = self._metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['celltype'])] )).values self._eval_grouper = CombinatorialGrouper( dataset=self, From 165d335e3cb1638c1e8b3117adcb04669045f52f Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Mon, 22 Mar 2021 11:24:26 -0700 Subject: [PATCH 096/244] metadata reorg + comment --- wilds/datasets/encodetfbs_dataset.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 4c62f74f..6ca6a828 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -143,17 +143,18 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' dnase_bw_path = os.path.join(self._data_dir, 'DNase/{}.bigwig'.format(ct)) self._dnase_allcelltypes[ct] = pyBigWig.open(dnase_bw_path) - self._metadata_array = torch.stack( - (torch.LongTensor(chr_ints), - torch.LongTensor(celltype_ints) - ), - dim=1) + # Set up metadata fields, map, array self._metadata_fields = ['chr', 'celltype'] self._metadata_map = {} self._metadata_map['chr'] = self._all_chroms self._metadata_map['celltype'] = self._all_celltypes chr_ints = self._metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['chr'])] )).values celltype_ints = self._metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['celltype'])] )).values + self._metadata_array = torch.stack( + (torch.LongTensor(chr_ints), + torch.LongTensor(celltype_ints) + ), + dim=1) self._eval_grouper = CombinatorialGrouper( dataset=self, From 1ab46307eed664bf1db2d976c175af688a48d94a Mon Sep 17 00:00:00 2001 From: aikanor Date: Mon, 22 Mar 2021 13:49:19 -0700 Subject: [PATCH 097/244] minor refactoring --- sbox_run_expt.ipynb | 585 --------------------------- wilds/datasets/encodetfbs_dataset.py | 34 +- 2 files changed, 20 insertions(+), 599 deletions(-) delete mode 100644 sbox_run_expt.ipynb diff --git a/sbox_run_expt.ipynb b/sbox_run_expt.ipynb deleted file mode 100644 index 39f5e862..00000000 --- a/sbox_run_expt.ipynb +++ /dev/null @@ -1,585 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# run_expt.py contents" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "ename": "SyntaxError", - "evalue": "invalid syntax (version.py, line 20)", - "output_type": "error", - "traceback": [ - "\u001b[0;36m File \u001b[0;32m\"wilds/version.py\"\u001b[0;36m, line \u001b[0;32m20\u001b[0m\n\u001b[0;31m f'The WILDS package is out of date. Your version is {__version__}, while the latest version is {latest}.')\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax\n" - ] - } - ], - "source": [ - "import os, csv\n", - "import time\n", - "import argparse\n", - "import pandas as pd\n", - "import torch\n", - "import torch.nn as nn\n", - "import torchvision\n", - "import sys\n", - "from collections import defaultdict\n", - "\n", - "from wilds.common.data_loaders import get_train_loader, get_eval_loader\n", - "from wilds.common.grouper import CombinatorialGrouper\n", - "\n", - "from utils import set_seed, Logger, BatchLogger, log_config, ParseKwargs, load, initialize_wandb, log_group_data, parse_bool\n", - "from train import train, evaluate\n", - "from algorithms.initializer import initialize_algorithm\n", - "from transforms import initialize_transform\n", - "from configs.utils import populate_defaults\n", - "import configs.supported as supported" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Initialize dataset object" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "57.8772239685\n", - "66.8270189762\n" - ] - } - ], - "source": [ - "import numpy as np, pandas as pd, os, time, torch, torchvision\n", - "data_dir = '/oak/stanford/groups/akundaje/abalsubr/DREAM/wilds/codalab_archive/'\n", - "tf = 'MAX'\n", - "itime = time.time()\n", - "train_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.train.labels.tsv.gz'.format(tf)), sep='\\t')\n", - "print(time.time() - itime)\n", - "val_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.val.labels.tsv.gz'.format(tf)), sep='\\t')\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", - "val_celltype = ['A549']\n", - "test_celltype = ['GM12878']\n", - "all_celltypes = train_celltypes + val_celltype + test_celltype\n", - "\n", - "metadata_map = {}\n", - "metadata_map['chr'] = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']\n", - "metadata_map['celltype'] = all_celltypes\n", - "\n", - "_split_dict = {\n", - " 'train': 0,\n", - " 'val-id': 1,\n", - " 'test': 2,\n", - " 'val-ood': 3\n", - "}\n", - "_split_names = {\n", - " 'train': 'Train',\n", - " 'val-id': 'Validation (ID)',\n", - " 'test': 'Test',\n", - " 'val-ood': 'Validation (OOD)'\n", - "}\n", - "_split_scheme = 'standard'" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "('H1-hESC', 25.299736976623535)\n", - "('HCT116', 49.68733310699463)\n", - "('HeLa-S3', 74.65905213356018)\n", - "('HepG2', 99.33112812042236)\n", - "('K562', 124.1327919960022)\n", - "('A549', 149.19999814033508)\n", - "('GM12878', 174.0277030467987)\n" - ] - } - ], - "source": [ - "itime = time.time()\n", - "sequence_filename = os.path.join(data_dir, 'sequence.npz')\n", - "seq_arr = np.load(sequence_filename)\n", - "print(time.time() - itime)\n", - "\n", - "itime = time.time()\n", - "_seq_bp = {}\n", - "for chrom in seq_arr:\n", - " _seq_bp[chrom] = seq_arr[chrom]\n", - " print(chrom, time.time() - itime)\n", - "itime = time.time()\n", - "_dnase_allcelltypes = {}\n", - "for ct in all_celltypes:\n", - " dnase_filename = os.path.join(data_dir, '{}_dnase.npz'.format(ct))\n", - " dnase_npz_file = np.load(dnase_filename)\n", - " _dnase_allcelltypes[ct] = {}\n", - " for chrom in _seq_bp:\n", - " _dnase_allcelltypes[ct][chrom] = dnase_npz_file[chrom]\n", - " print(ct, time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 78, - "metadata": {}, - "outputs": [], - "source": [ - "import math\n", - "import numpy as np\n", - "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "\n", - "class Beagle(nn.Module):\n", - " \"\"\"\n", - " Neural net models over genomic sequence.\n", - " Input:\n", - " - sequence_length: int (default 1000) \n", - " - Shape: (N, 5, sequence_length, 1) with batch size N.\n", - " \n", - " Output:\n", - " - prediction (Tensor): float torch tensor of shape (N, )\n", - " \n", - " TODO: Finish docstring.\n", - " \"\"\"\n", - " def __init__(self):\n", - " \"\"\"\n", - " Parameters\n", - " ----------\n", - " sequence_length : int\n", - " n_genomic_features : int\n", - " \"\"\"\n", - " super(Beagle, self).__init__()\n", - "\n", - " self.dropout = 0.3\n", - " self.num_cell_types = 1\n", - " self.conv1 = nn.Conv2d(5, 300, (19, 1), stride = (1, 1), padding=(9,0))\n", - " self.conv2 = nn.Conv2d(300, 200, (11, 1), stride = (1, 1), padding = (5,0))\n", - " self.conv3 = nn.Conv2d(200, 200, (7, 1), stride = (1, 1), padding = (4,0))\n", - " self.bn1 = nn.BatchNorm2d(300)\n", - " self.bn2 = nn.BatchNorm2d(200)\n", - " self.bn3 = nn.BatchNorm2d(200)\n", - " self.maxpool1 = nn.MaxPool2d((3, 1))\n", - " self.maxpool2 = nn.MaxPool2d((4, 1))\n", - " self.maxpool3 = nn.MaxPool2d((4, 1))\n", - "\n", - " self.fc1 = nn.Linear(4200, 1000)\n", - " self.bn4 = nn.BatchNorm1d(1000)\n", - "\n", - " self.fc2 = nn.Linear(1000, 1000)\n", - " self.bn5 = nn.BatchNorm1d(1000)\n", - "\n", - " self.fc3 = nn.Linear(1000, self.num_cell_types)\n", - "\n", - " def forward(self, s):\n", - " s = s.permute(0, 2, 1).contiguous() # batch_size x 5 x 1000\n", - " s = s.view(-1, 5, 1000, 1) # batch_size x 5 x 1000 x 1 [5 channels]\n", - " s = self.maxpool1(F.relu(self.bn1(self.conv1(s)))) # batch_size x 300 x 333 x 1\n", - " s = self.maxpool2(F.relu(self.bn2(self.conv2(s)))) # batch_size x 200 x 83 x 1\n", - " s = self.maxpool3(F.relu(self.bn3(self.conv3(s)))) # batch_size x 200 x 21 x 1\n", - " s = s.view(-1, 4200)\n", - " conv_out = s\n", - "\n", - " s = F.dropout(F.relu(self.bn4(self.fc1(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", - " s = F.dropout(F.relu(self.bn5(self.fc2(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", - " \n", - " s = self.fc3(s)\n", - "\n", - " return s, conv_out" - ] - }, - { - "cell_type": "code", - "execution_count": 86, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[('nnet.0.weight', 33280),\n", - " ('nnet.0.bias', 320),\n", - " ('bdlstm.0.weight_ih_l0', 409600),\n", - " ('bdlstm.0.weight_hh_l0', 409600),\n", - " ('bdlstm.0.bias_ih_l0', 1280),\n", - " ('bdlstm.0.bias_hh_l0', 1280),\n", - " ('bdlstm.0.weight_ih_l0_reverse', 409600),\n", - " ('bdlstm.0.weight_hh_l0_reverse', 409600),\n", - " ('bdlstm.0.bias_ih_l0_reverse', 1280),\n", - " ('bdlstm.0.bias_hh_l0_reverse', 1280),\n", - " ('classifier.1.weight', 592000),\n", - " ('classifier.1.bias', 925),\n", - " ('classifier.3.weight', 4625),\n", - " ('classifier.3.bias', 5)]" - ] - }, - "execution_count": 86, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def count_parameters(model):\n", - " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", - "\n", - "model = Beagle2()\n", - "model = DanQ(50, 5)\n", - "\n", - "lst = [(x[0], x[1].numel()) for x in model.named_parameters()]\n", - "#np.sum([x[1] for x in lst])\n", - "count_parameters(model)\n", - "lst" - ] - }, - { - "cell_type": "code", - "execution_count": 48, - "metadata": {}, - "outputs": [ - { - "ename": "AttributeError", - "evalue": "'module' object has no attribute 'isin'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mtr_chrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr2'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr9'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr11'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mte_chrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr1'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr8'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr21'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mtraining_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtr_chrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mval_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mval_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mte_chrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mall_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtraining_df\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_df\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mAttributeError\u001b[0m: 'module' object has no attribute 'isin'" - ] - } - ], - "source": [ - "tr_chrs = ['chr2', 'chr9', 'chr11']\n", - "te_chrs = ['chr1', 'chr8', 'chr21']\n", - "training_df = train_chr[np.isin(train_chr['chr'], tr_chrs)]\n", - "val_df = val_chr[np.isin(val_chr['chr'], te_chrs)]\n", - "all_df = pd.concat([training_df, val_df])\n", - "\n", - "#filter_msk = all_df['start'] >= 0\n", - "filter_msk = all_df['start']%1000 == 0\n", - "all_df = all_df[filter_msk]" - ] - }, - { - "cell_type": "code", - "execution_count": 49, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1.12.1\n" - ] - } - ], - "source": [ - "print(np.__version__)" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/users/abalsubr/anaconda2/envs/scs3/lib/python3.6/site-packages/ipykernel_launcher.py:6: SettingWithCopyWarning: \n", - "A value is trying to be set on a copy of a slice from a DataFrame.\n", - "Try using .loc[row_indexer,col_indexer] = value instead\n", - "\n", - "See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy\n", - " \n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1.659163236618042\n" - ] - } - ], - "source": [ - "itime = time.time()\n", - "pd_list = []\n", - "for ct in all_celltypes:\n", - " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", - " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", - " tc_chr['celltype'] = ct\n", - " pd_list.append(tc_chr)\n", - "metadata_df = pd.concat(pd_list)\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "3.0391879081726074\n" - ] - } - ], - "source": [ - "itime = time.time()\n", - "y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", - "non_ambig_mask = (y_array != -1)\n", - "metadata_df['y'] = y_array\n", - "_metadata_df = metadata_df[non_ambig_mask]\n", - "_y_array = torch.LongTensor(y_array[non_ambig_mask])\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "12.390011310577393\n" - ] - } - ], - "source": [ - "itime = time.time()\n", - "chr_ints = _metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['chr'])] )).values\n", - "celltype_ints = _metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['celltype'])] )).values\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 53, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/users/abalsubr/anaconda2/envs/scs3/lib/python3.6/site-packages/ipykernel_launcher.py:12: SettingWithCopyWarning: \n", - "A value is trying to be set on a copy of a slice from a DataFrame.\n", - "Try using .loc[row_indexer,col_indexer] = value instead\n", - "\n", - "See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy\n", - " if sys.path[0] == '':\n" - ] - } - ], - "source": [ - "train_chr_mask = np.isin(_metadata_df['chr'], tr_chrs)\n", - "val_chr_mask = np.isin(_metadata_df['chr'], te_chrs)\n", - "train_celltype_mask = np.isin(_metadata_df['celltype'], train_celltypes)\n", - "val_celltype_mask = np.isin(_metadata_df['celltype'], val_celltype)\n", - "test_celltype_mask = np.isin(_metadata_df['celltype'], test_celltype)\n", - "\n", - "split_array = -1*np.ones(_metadata_df.shape[0]).astype(int)\n", - "split_array[np.logical_and(train_chr_mask, train_celltype_mask)] = _split_dict['train']\n", - "split_array[np.logical_and(val_chr_mask, test_celltype_mask)] = _split_dict['test']\n", - "split_array[np.logical_and(val_chr_mask, val_celltype_mask)] = _split_dict['val-ood']\n", - "split_array[np.logical_and(val_chr_mask, train_celltype_mask)] = _split_dict['val-id']\n", - "_metadata_df['split'] = split_array\n", - "_split_array = split_array" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# get_input (idx)" - ] - }, - { - "cell_type": "code", - "execution_count": 153, - "metadata": {}, - "outputs": [], - "source": [ - "idx = 3\n", - "this_metadata = _metadata_df.iloc[idx, :]\n", - "\n", - "itime = time.time()\n", - "flank_size = 400\n", - "interval_start = this_metadata['start'] - flank_size\n", - "interval_end = this_metadata['stop'] + flank_size\n", - "dnase_this = _dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end]\n", - "seq_this = _seq_bp[this_metadata['chr']][interval_start:interval_end]\n", - "data = np.column_stack([seq_this, dnase_this])\n", - "# print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 154, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "4600" - ] - }, - "execution_count": 154, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "data.shape\n", - "interval_end\n", - "# itime = time.time()\n", - "# np.save(os.path.join(data_dir, 'stmp.npy'), sa)\n", - "# print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 78, - "metadata": {}, - "outputs": [ - { - "ename": "TypeError", - "evalue": "can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mitime\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m metadata_array = torch.stack(\n\u001b[0;32m----> 3\u001b[0;31m (torch.LongTensor(metadata_df['chr'].values), \n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLongTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetadata_df\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'celltype'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m self._y_array),\n", - "\u001b[0;31mTypeError\u001b[0m: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool." - ] - } - ], - "source": [ - "itime = time.time()\n", - "metadata_array = torch.stack(\n", - " (torch.LongTensor(chr_ints), \n", - " torch.LongTensor(celltype_ints), \n", - " _y_array),\n", - " dim=1)\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 156, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name '_metadata_array' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0m_metadata_array\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mNameError\u001b[0m: name '_metadata_array' is not defined" - ] - } - ], - "source": [ - "_metadata_array" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "from examples.models.model_attributes import model_attributes" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.5" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 6ca6a828..90de3245 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -54,24 +54,30 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' self.y_array[self.y_array == 0.5] = float('nan') # Construct splits + train_chroms = ['chr3']#, 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] + val_chroms = ['chr2']#, 'chr9', 'chr11'] + test_chroms = ['chr1']#, 'chr8', 'chr21'] + train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] + val_celltype = ['A549'] + test_celltype = ['GM12878'] self._split_scheme = split_scheme if self._split_scheme == 'official': splits = { 'train': { - 'chroms': ['chr3'], - 'celltypes': ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] + 'chroms': train_chroms, + 'celltypes': train_celltypes }, 'id_val': { - 'chroms': ['chr2'], - 'celltypes': ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] + 'chroms': val_chroms, + 'celltypes': train_celltypes }, 'val': { - 'chroms': ['chr2'], - 'celltypes': ['A549'] + 'chroms': val_chroms, + 'celltypes': val_celltype }, 'test': { - 'chroms': ['chr1'], - 'celltypes': ['GM12878'] + 'chroms': test_chroms, + 'celltypes': test_celltype }, } self._split_dict = { @@ -89,16 +95,16 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' elif self._split_scheme == 'in-dist': splits = { 'train': { - 'chroms': ['chr3'], - 'celltypes': ['GM12878'], + 'chroms': train_chroms, + 'celltypes': test_celltype, }, 'val': { - 'chroms': ['chr2'], - 'celltypes': ['GM12878'] + 'chroms': val_chroms, + 'celltypes': test_celltype }, 'test': { - 'chroms': ['chr1'], - 'celltypes': ['GM12878'] + 'chroms': test_chroms, + 'celltypes': test_celltype }, } self._split_dict = { From e66af2e1a63448275b9d57af6a90c271557dd63e Mon Sep 17 00:00:00 2001 From: aikanor Date: Mon, 22 Mar 2021 14:24:24 -0700 Subject: [PATCH 098/244] refactoring eval 1/2 --- .../encode-tfbs/prep_metadata_labels.ipynb | 98 ++++++++++++++++--- .../encode-tfbs/prep_metadata_labels.py | 83 ++++++++++++++++ 2 files changed, 170 insertions(+), 11 deletions(-) create mode 100644 dataset_preprocessing/encode-tfbs/prep_metadata_labels.py diff --git a/dataset_preprocessing/encode-tfbs/prep_metadata_labels.ipynb b/dataset_preprocessing/encode-tfbs/prep_metadata_labels.ipynb index 9748bd25..2bfe5729 100644 --- a/dataset_preprocessing/encode-tfbs/prep_metadata_labels.ipynb +++ b/dataset_preprocessing/encode-tfbs/prep_metadata_labels.ipynb @@ -12,16 +12,18 @@ "import pyBigWig\n", "\n", "# Human chromosome names\n", - "chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']" + "chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']\n", + "_data_dir = '../../examples/data/encode-tfbs_v1.0/'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Prep metadata df and metadata array surrounding the labels\n", + "# Prep metadata df and metadata/label array\n", "- Metadata df contains 6400bp (window_size/2) prediction windows across the genome. Each gets a 128-bit prediction from the model.\n", - "- We store the ones that aren't fully unbound. All the rest are fully unbound." + "- We store the ones that aren't fully unbound, and write these to bigwigs representing genome-wide labels.\n", + "- Then read from the bigwigs to make a metadata dataframe." ] }, { @@ -49,7 +51,6 @@ "source": [ "itime = time.time()\n", "\n", - "_data_dir = '../../examples/data/encode-tfbs_v1.0/'\n", "_transcription_factor = 'MAX'\n", "_train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX']\n", "_val_chroms = ['chr2', 'chr9', 'chr11']\n", @@ -98,6 +99,13 @@ ")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Write the binned labels to bigwig files - genome-wide labels" + ] + }, { "cell_type": "code", "execution_count": 6, @@ -130,6 +138,13 @@ " bw.close()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Read from genome-wide label bigwigs to generate metadata for the bound sites" + ] + }, { "cell_type": "code", "execution_count": 7, @@ -166,6 +181,10 @@ "source": [ "stride = 6400\n", "itime = time.time()\n", + "mdf_posamb = pd.read_csv(\n", + " _sorted_dir, \n", + " sep='\\t', header=None, index_col=None, names=['chr', 'start', 'stop', 'y', 'celltype']\n", + ")\n", "celltype_mdta = []\n", "celltype_labels = []\n", "\n", @@ -205,9 +224,70 @@ "print(time.time() - itime)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Add the all-unbound sites" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "mdf = pd.read_csv(\n", + " _data_dir + 'labels/MAX/metadata_df.bed', sep='\\t', header=None, \n", + " index_col=None, names=['chr', 'start', 'stop', 'celltype']\n", + ")\n", + "mdy = np.load(_data_dir + 'labels/MAX/metadata_y.npy')" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(array([0. , 0.5, 1. ], dtype=float32), array([57034243, 7888122, 2118659]))" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.unique(mdy, return_counts=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ 0, 6400, 12800, ..., 249203200, 249216000,\n", + " 249222400])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.unique(mdf['start'])" + ] + }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 17, "metadata": {}, "outputs": [ { @@ -337,17 +417,13 @@ "[523758 rows x 4 columns]" ] }, - "execution_count": 19, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "pd.read_csv(\n", - " _data_dir + 'labels/MAX/metadata_df.bed', sep='\\t', header=None, \n", - " index_col=None, names=['chr', 'start', 'stop', 'celltype']\n", - ")\n", - "# np.load(_data_dir + 'labels/MAX/metadata_y.npy')" + "mdf" ] }, { diff --git a/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py b/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py new file mode 100644 index 00000000..a12aaa7b --- /dev/null +++ b/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py @@ -0,0 +1,83 @@ +import os, csv +import scipy, numpy as np, pandas as pd, time +from scipy import sparse +import pyBigWig + +# Human chromosome names +chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] +chrom_sizes = {'chr1': 249250621, 'chr10': 135534747, 'chr11': 135006516, 'chr12': 133851895, 'chr13': 115169878, 'chr14': 107349540, 'chr15': 102531392, 'chr16': 90354753, 'chr17': 81195210, 'chr18': 78077248, 'chr19': 59128983, 'chr2': 243199373, 'chr20': 63025520, 'chr21': 48129895, 'chr22': 51304566, 'chr3': 198022430, 'chr4': 191154276, 'chr5': 180915260, 'chr6': 171115067, 'chr7': 159138663, 'chr8': 146364022, 'chr9': 141213431, 'chrX': 155270560} + +_data_dir = '../../examples/data/encode-tfbs_v1.0/' + + +def write_label_bigwigs(): + itime = time.time() + transcription_factor = 'MAX' + _train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] + _val_chroms = ['chr2', 'chr9', 'chr11'] + _test_chroms = ['chr1', 'chr8', 'chr21'] + _all_chroms = _train_chroms + _val_chroms + _test_chroms + _train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] + _val_celltype = ['A549'] + _test_celltype = ['GM12878'] + _all_celltypes = _train_celltypes + _val_celltype + _test_celltype + + # Read in metadata dataframe from training+validation data + train_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.train.labels.tsv.gz'.format(_transcription_factor)), sep='\t') + val_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.val.labels.tsv.gz'.format(_transcription_factor)), sep='\t') + training_df = train_regions_labeled# [np.isin(train_regions_labeled['chr'], _train_chroms)] + val_df = val_regions_labeled# [np.isin(val_regions_labeled['chr'], _test_chroms)] + all_df = pd.concat([training_df, val_df]) + + print(time.time() - itime) + + # Get the y values, and remove labels by default. + pd_list = [] + for ct in _all_celltypes: + tc_chr = all_df[['chr', 'start', 'stop', ct]] + tc_chr.columns = ['chr', 'start', 'stop', 'y'] + tc_chr = tc_chr[tc_chr['y'] != 'U'] + tc_chr['y'] = tc_chr['y'].replace({'U': 0, 'B': 1, 'A': 0.5}).values + + tc_chr.insert(len(tc_chr.columns), 'celltype', ct) + pd_list.append(tc_chr) + print(ct, time.time() - itime) + _metadata_df = pd.concat(pd_list) + + print(time.time() - itime) + _unsorted_dir = _data_dir + 'labels/{}/{}_posamb.bed'.format( + transcription_factor, transcription_factor) + _sorted_dir = _unsorted_dir.replace( + '{}_posamb'.format(transcription_factor), + '{}_posamb.sorted'.format(transcription_factor) + ) + _metadata_df.to_csv( + _unsorted_dir, sep='\t', header=False, index=False + ) + print(time.time() - itime) + + os.system('sort -k1,1 -k2,2n {} > {}'.format(_unsorted_dir, _sorted_dir)) + + mdf_posamb = pd.read_csv( + _sorted_dir, + sep='\t', header=None, index_col=None, names=['chr', 'start', 'stop', 'y', 'celltype'] + ) + + # Write the binned labels to bigwig files - genome-wide labels + chromsizes_list = [(k, v) for k, v in chrom_sizes.items()] + for ct in _all_celltypes: + ct_labels_bw_path = _data_dir + "labels/{}/{}_{}.bigwig".format( + transcription_factor, transcription_factor, ct) + df = mdf_posamb[mdf_posamb['celltype'] == ct] + bw = pyBigWig.open(ct_labels_bw_path, "w") + bw.addHeader(chromsizes_list) + bw.addEntries(list(df['chr']), list(df['start']), ends=list(df['start']+50), values=list(df['y'])) + print(ct, time.time() - itime) + bw.close() + + +if __name__ == '__main__': + write_label_bigwigs() + generate_accessibility_archives( + input_dir=args.input_dir, + output_dir=args.output_dir) \ No newline at end of file From 16178146945e21654be9c0d2b9d604b780bd3992 Mon Sep 17 00:00:00 2001 From: aikanor Date: Mon, 22 Mar 2021 15:51:12 -0700 Subject: [PATCH 099/244] refactoring eval 1.5/2 --- .../encode-tfbs/prep_metadata_labels.ipynb | 377 ++++++++++-------- .../encode-tfbs/prep_metadata_labels.py | 4 +- 2 files changed, 208 insertions(+), 173 deletions(-) diff --git a/dataset_preprocessing/encode-tfbs/prep_metadata_labels.ipynb b/dataset_preprocessing/encode-tfbs/prep_metadata_labels.ipynb index 2bfe5729..235f0600 100644 --- a/dataset_preprocessing/encode-tfbs/prep_metadata_labels.ipynb +++ b/dataset_preprocessing/encode-tfbs/prep_metadata_labels.ipynb @@ -2,14 +2,14 @@ "cells": [ { "cell_type": "code", - "execution_count": 3, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os, csv\n", "import scipy, numpy as np, pandas as pd, time\n", "from scipy import sparse\n", - "import pyBigWig\n", + "import pyBigWig, prep_metadata_labels\n", "\n", "# Human chromosome names\n", "chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']\n", @@ -22,14 +22,18 @@ "source": [ "# Prep metadata df and metadata/label array\n", "- Metadata df contains 6400bp (window_size/2) prediction windows across the genome. Each gets a 128-bit prediction from the model.\n", - "- We store the ones that aren't fully unbound, and write these to bigwigs representing genome-wide labels.\n", - "- Then read from the bigwigs to make a metadata dataframe." + "- We store the ones that aren't fully unbound, and write these to bigwigs representing genome-wide labels." ] }, { "cell_type": "code", "execution_count": 4, - "metadata": {}, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, "outputs": [ { "name": "stdout", @@ -49,100 +53,14 @@ } ], "source": [ - "itime = time.time()\n", - "\n", - "_transcription_factor = 'MAX'\n", - "_train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX']\n", - "_val_chroms = ['chr2', 'chr9', 'chr11']\n", - "_test_chroms = ['chr1', 'chr8', 'chr21']\n", - "_all_chroms = _train_chroms + _val_chroms + _test_chroms\n", - "_train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", - "_val_celltype = ['A549']\n", - "_test_celltype = ['GM12878']\n", - "_all_celltypes = _train_celltypes + _val_celltype + _test_celltype\n", - "\n", - "# Read in metadata dataframe from training+validation data\n", - "train_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.train.labels.tsv.gz'.format(_transcription_factor)), sep='\\t')\n", - "val_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.val.labels.tsv.gz'.format(_transcription_factor)), sep='\\t')\n", - "training_df = train_regions_labeled# [np.isin(train_regions_labeled['chr'], _train_chroms)]\n", - "val_df = val_regions_labeled# [np.isin(val_regions_labeled['chr'], _test_chroms)]\n", - "all_df = pd.concat([training_df, val_df])\n", - "\n", - "print(time.time() - itime)\n", - "\n", - "# Get the y values, and remove labels by default.\n", - "pd_list = []\n", - "for ct in _all_celltypes:\n", - " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", - " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", - " tc_chr = tc_chr[tc_chr['y'] != 'U']\n", - " tc_chr['y'] = tc_chr['y'].replace({'U': 0, 'B': 1, 'A': 0.5}).values\n", - " \n", - " tc_chr.insert(len(tc_chr.columns), 'celltype', ct)\n", - " pd_list.append(tc_chr)\n", - " print(ct, time.time() - itime)\n", - "_metadata_df = pd.concat(pd_list)\n", - "\n", - "print(time.time() - itime)\n", - "_unsorted_dir = _data_dir + 'labels/MAX/MAX_posamb.bed'\n", - "_sorted_dir = _unsorted_dir.replace('MAX_posamb', 'MAX_posamb.sorted')\n", - "_metadata_df.to_csv(\n", - " _unsorted_dir, sep='\\t', header=False, index=False\n", - ")\n", - "print(time.time() - itime)\n", - "\n", - "os.system('sort -k1,1 -k2,2n {} > {}'.format(_unsorted_dir, _sorted_dir))\n", - "\n", - "mdf_posamb = pd.read_csv(\n", - " _sorted_dir, \n", - " sep='\\t', header=None, index_col=None, names=['chr', 'start', 'stop', 'y', 'celltype']\n", - ")" + "prep_metadata_labels.write_label_bigwigs()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Write the binned labels to bigwig files - genome-wide labels" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "H1-hESC 350.84476041793823\n", - "HCT116 358.2693498134613\n", - "HeLa-S3 364.6210968494415\n", - "HepG2 372.65956830978394\n", - "K562 380.6701240539551\n", - "A549 388.50364875793457\n", - "GM12878 394.2549338340759\n" - ] - } - ], - "source": [ - "chrom_sizes = {'chr1': 249250621, 'chr10': 135534747, 'chr11': 135006516, 'chr12': 133851895, 'chr13': 115169878, 'chr14': 107349540, 'chr15': 102531392, 'chr16': 90354753, 'chr17': 81195210, 'chr18': 78077248, 'chr19': 59128983, 'chr2': 243199373, 'chr20': 63025520, 'chr21': 48129895, 'chr22': 51304566, 'chr3': 198022430, 'chr4': 191154276, 'chr5': 180915260, 'chr6': 171115067, 'chr7': 159138663, 'chr8': 146364022, 'chr9': 141213431, 'chrX': 155270560}\n", - "chromsizes_list = [(k, v) for k, v in chrom_sizes.items()]\n", - "for ct in _all_celltypes:\n", - " ct_labels_bw_path = _data_dir + \"labels/MAX/MAX_{}.bigwig\".format(ct)\n", - " df = mdf_posamb[mdf_posamb['celltype'] == ct]\n", - " bw = pyBigWig.open(ct_labels_bw_path, \"w\")\n", - " bw.addHeader(chromsizes_list)\n", - " bw.addEntries(list(df['chr']), list(df['start']), ends=list(df['start']+50), values=list(df['y']))\n", - " print(ct, time.time() - itime)\n", - " bw.close()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Read from genome-wide label bigwigs to generate metadata for the bound sites" + "- Then read from the bigwigs to generate metadata for the bound sites." ] }, { @@ -233,61 +151,180 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "H1-hESC chr1 31.691891193389893\n", + "H1-hESC chr10 43.88507628440857\n", + "H1-hESC chr11 54.64318251609802\n", + "H1-hESC chr12 63.76666021347046\n", + "H1-hESC chr13 72.60888147354126\n", + "H1-hESC chr14 78.53658175468445\n", + "H1-hESC chr15 84.56542801856995\n", + "H1-hESC chr16 92.28407764434814\n", + "H1-hESC chr17 99.54330348968506\n", + "H1-hESC chr18 106.55353927612305\n", + "H1-hESC chr19 111.41691207885742\n", + "H1-hESC chr2 135.3123984336853\n", + "H1-hESC chr20 141.9123089313507\n", + "H1-hESC chr21 146.14480471611023\n", + "H1-hESC chr22 150.8621871471405\n", + "H1-hESC chr3 169.92432117462158\n", + "H1-hESC chr4 186.69121527671814\n", + "H1-hESC chr5 201.6394476890564\n", + "H1-hESC chr6 215.72684383392334\n", + "H1-hESC chr7 227.8461310863495\n", + "H1-hESC chr8 240.26825499534607\n", + "H1-hESC chr9 250.02118062973022\n", + "H1-hESC chrX 264.7451572418213\n", + "H1-hESC 267.0940718650818\n", + "HCT116 chr1 291.3232545852661\n", + "HCT116 chr10 304.0528976917267\n", + "HCT116 chr11 316.63377356529236\n", + "HCT116 chr12 329.559387922287\n", + "HCT116 chr13 341.52057003974915\n", + "HCT116 chr14 350.64817333221436\n", + "HCT116 chr15 359.71765422821045\n", + "HCT116 chr16 368.28583669662476\n", + "HCT116 chr17 376.01680874824524\n", + "HCT116 chr18 383.57749581336975\n", + "HCT116 chr19 389.313095331192\n", + "HCT116 chr2 412.09550070762634\n", + "HCT116 chr20 417.96392583847046\n", + "HCT116 chr21 422.00927662849426\n", + "HCT116 chr22 426.71226167678833\n", + "HCT116 chr3 442.76402711868286\n", + "HCT116 chr4 461.9360821247101\n", + "HCT116 chr5 478.7654387950897\n", + "HCT116 chr6 495.16735339164734\n", + "HCT116 chr7 511.82248401641846\n", + "HCT116 chr8 525.3609001636505\n", + "HCT116 chr9 538.2295203208923\n", + "HCT116 chrX 553.2177627086639\n", + "HCT116 555.6791486740112\n", + "HeLa-S3 chr1 580.0471041202545\n", + "HeLa-S3 chr10 594.5216126441956\n", + "HeLa-S3 chr11 606.1479568481445\n", + "HeLa-S3 chr12 618.3873989582062\n", + "HeLa-S3 chr13 628.7777881622314\n", + "HeLa-S3 chr14 637.718688249588\n", + "HeLa-S3 chr15 647.1737523078918\n", + "HeLa-S3 chr16 655.4854662418365\n", + "HeLa-S3 chr17 662.5001983642578\n", + "HeLa-S3 chr18 671.1846849918365\n", + "HeLa-S3 chr19 677.9798579216003\n", + "HeLa-S3 chr2 700.6258955001831\n", + "HeLa-S3 chr20 706.6806621551514\n", + "HeLa-S3 chr21 710.3620142936707\n", + "HeLa-S3 chr22 714.1444058418274\n", + "HeLa-S3 chr3 733.4964163303375\n", + "HeLa-S3 chr4 751.4331798553467\n", + "HeLa-S3 chr5 768.5986630916595\n" + ] + } + ], + "source": [ + "stride = 6400\n", + "itime = time.time()\n", + "mdf_posamb = pd.read_csv(\n", + " _data_dir + 'labels/MAX/MAX_posamb.sorted.bed', \n", + " sep='\\t', header=None, index_col=None, names=['chr', 'start', 'stop', 'y', 'celltype']\n", + ")\n", + "celltype_mdta = []\n", + "celltype_labels = []\n", + "_train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", + "_val_celltype = ['A549']\n", + "_test_celltype = ['GM12878']\n", + "_all_celltypes = _train_celltypes + _val_celltype + _test_celltype\n", + "\n", + "for ct in _all_celltypes:\n", + " ct_labels_bw_path = _data_dir + \"labels/MAX/MAX_{}.bigwig\".format(ct)\n", + " df_construction = []\n", + " mdta_labels = []\n", + " bw = pyBigWig.open(ct_labels_bw_path)\n", + " for chrID in bw.chroms():\n", + " chromsize = bw.chroms()[chrID]\n", + " # Iterate over windows\n", + " for startc in np.arange(0, chromsize, stride):\n", + " u_end = startc + stride\n", + " if u_end > chromsize:\n", + " break\n", + " x = np.nan_to_num(bw.values(chrID, startc, u_end, numpy=True))\n", + " df_construction.append((chrID, startc, u_end))\n", + " mdta_labels.append(x[np.arange(0, len(x), 50)])\n", + " print(ct, chrID, time.time() - itime)\n", + " celltype_mdta_df = pd.DataFrame(df_construction, columns=['chr', 'start', 'stop'])\n", + " celltype_mdta_df.insert(len(celltype_mdta_df.columns), 'celltype', ct)\n", + " celltype_mdta.append(celltype_mdta_df)\n", + " celltype_labels.append(np.stack(mdta_labels))\n", + " print(ct, time.time() - itime)\n", + " bw.close()\n", + " # break\n", + "print(time.time() - itime)" + ] + }, + { + "cell_type": "code", + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "mdf = pd.read_csv(\n", - " _data_dir + 'labels/MAX/metadata_df.bed', sep='\\t', header=None, \n", - " index_col=None, names=['chr', 'start', 'stop', 'celltype']\n", + "all_metadata_df = pd.concat(celltype_mdta)\n", + "print(time.time() - itime)\n", + "all_metadata_df.to_csv(\n", + " _data_dir + 'labels/MAX/all_metadata_df.bed', \n", + " sep='\\t', header=False, index=False\n", ")\n", - "mdy = np.load(_data_dir + 'labels/MAX/metadata_y.npy')" + "print(time.time() - itime)\n", + "np.save(_data_dir + 'labels/MAX/all_metadata_y.npy', np.vstack(celltype_labels))\n", + "print(time.time() - itime)" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 16, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "(array([0. , 0.5, 1. ], dtype=float32), array([57034243, 7888122, 2118659]))" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "1609.5078670978546\n" + ] } ], "source": [ - "np.unique(mdy, return_counts=True)" + "np.save(_data_dir + 'labels/MAX/all_metadata_y.npy', np.vstack(celltype_labels))\n", + "print(time.time() - itime)" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "array([ 0, 6400, 12800, ..., 249203200, 249216000,\n", - " 249222400])" + "(169827, 128)" ] }, - "execution_count": 16, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "np.unique(mdf['start'])" + "np.vstack(celltype_labels).shape" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 20, "metadata": {}, "outputs": [ { @@ -320,38 +357,38 @@ " \n", " \n", " 0\n", - " chr10\n", - " 100025600\n", - " 100032000\n", - " A549\n", + " chrX\n", + " 0\n", + " 6400\n", + " GM12878\n", " \n", " \n", " 1\n", - " chr10\n", - " 100032000\n", - " 100038400\n", - " A549\n", + " chrX\n", + " 6400\n", + " 12800\n", + " GM12878\n", " \n", " \n", " 2\n", - " chr10\n", - " 100064000\n", - " 100070400\n", - " A549\n", + " chrX\n", + " 12800\n", + " 19200\n", + " GM12878\n", " \n", " \n", " 3\n", - " chr10\n", - " 100076800\n", - " 100083200\n", - " A549\n", + " chrX\n", + " 19200\n", + " 25600\n", + " GM12878\n", " \n", " \n", " 4\n", - " chr10\n", - " 100083200\n", - " 100089600\n", - " A549\n", + " chrX\n", + " 25600\n", + " 32000\n", + " GM12878\n", " \n", " \n", " ...\n", @@ -361,69 +398,69 @@ " ...\n", " \n", " \n", - " 523753\n", + " 24256\n", " chrX\n", - " 99699200\n", - " 99705600\n", - " K562\n", + " 155238400\n", + " 155244800\n", + " GM12878\n", " \n", " \n", - " 523754\n", + " 24257\n", " chrX\n", - " 9977600\n", - " 9984000\n", - " K562\n", + " 155244800\n", + " 155251200\n", + " GM12878\n", " \n", " \n", - " 523755\n", + " 24258\n", " chrX\n", - " 99904000\n", - " 99910400\n", - " K562\n", + " 155251200\n", + " 155257600\n", + " GM12878\n", " \n", " \n", - " 523756\n", + " 24259\n", " chrX\n", - " 99923200\n", - " 99929600\n", - " K562\n", + " 155257600\n", + " 155264000\n", + " GM12878\n", " \n", " \n", - " 523757\n", + " 24260\n", " chrX\n", - " 99993600\n", - " 100000000\n", - " K562\n", + " 155264000\n", + " 155270400\n", + " GM12878\n", " \n", " \n", "\n", - "

523758 rows × 4 columns

\n", + "

24261 rows × 4 columns

\n", "" ], "text/plain": [ - " chr start stop celltype\n", - "0 chr10 100025600 100032000 A549\n", - "1 chr10 100032000 100038400 A549\n", - "2 chr10 100064000 100070400 A549\n", - "3 chr10 100076800 100083200 A549\n", - "4 chr10 100083200 100089600 A549\n", - "... ... ... ... ...\n", - "523753 chrX 99699200 99705600 K562\n", - "523754 chrX 9977600 9984000 K562\n", - "523755 chrX 99904000 99910400 K562\n", - "523756 chrX 99923200 99929600 K562\n", - "523757 chrX 99993600 100000000 K562\n", + " chr start stop celltype\n", + "0 chrX 0 6400 GM12878\n", + "1 chrX 6400 12800 GM12878\n", + "2 chrX 12800 19200 GM12878\n", + "3 chrX 19200 25600 GM12878\n", + "4 chrX 25600 32000 GM12878\n", + "... ... ... ... ...\n", + "24256 chrX 155238400 155244800 GM12878\n", + "24257 chrX 155244800 155251200 GM12878\n", + "24258 chrX 155251200 155257600 GM12878\n", + "24259 chrX 155257600 155264000 GM12878\n", + "24260 chrX 155264000 155270400 GM12878\n", "\n", - "[523758 rows x 4 columns]" + "[24261 rows x 4 columns]" ] }, - "execution_count": 17, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "mdf" + "celltype_mdta_df" ] }, { diff --git a/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py b/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py index a12aaa7b..de156920 100644 --- a/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py +++ b/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py @@ -78,6 +78,4 @@ def write_label_bigwigs(): if __name__ == '__main__': write_label_bigwigs() - generate_accessibility_archives( - input_dir=args.input_dir, - output_dir=args.output_dir) \ No newline at end of file + \ No newline at end of file From 9b7d93c1eee24ca947f7a7a768e875a770d827ed Mon Sep 17 00:00:00 2001 From: aikanor Date: Mon, 22 Mar 2021 16:09:15 -0700 Subject: [PATCH 100/244] eval 2/2 --- .../encode-tfbs/prep_metadata_labels.ipynb | 305 +++++++----------- 1 file changed, 116 insertions(+), 189 deletions(-) diff --git a/dataset_preprocessing/encode-tfbs/prep_metadata_labels.ipynb b/dataset_preprocessing/encode-tfbs/prep_metadata_labels.ipynb index 235f0600..ac20f433 100644 --- a/dataset_preprocessing/encode-tfbs/prep_metadata_labels.ipynb +++ b/dataset_preprocessing/encode-tfbs/prep_metadata_labels.ipynb @@ -151,7 +151,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "metadata": {}, "outputs": [ { @@ -223,7 +223,110 @@ "HeLa-S3 chr22 714.1444058418274\n", "HeLa-S3 chr3 733.4964163303375\n", "HeLa-S3 chr4 751.4331798553467\n", - "HeLa-S3 chr5 768.5986630916595\n" + "HeLa-S3 chr5 768.5986630916595\n", + "HeLa-S3 chr6 784.270875453949\n", + "HeLa-S3 chr7 797.38880443573\n", + "HeLa-S3 chr8 810.2063493728638\n", + "HeLa-S3 chr9 822.0966618061066\n", + "HeLa-S3 chrX 836.4143695831299\n", + "HeLa-S3 838.8340592384338\n", + "HepG2 chr1 859.2195272445679\n", + "HepG2 chr10 871.4736864566803\n", + "HepG2 chr11 882.3986892700195\n", + "HepG2 chr12 892.8185703754425\n", + "HepG2 chr13 902.5888631343842\n", + "HepG2 chr14 910.1683735847473\n", + "HepG2 chr15 917.9265463352203\n", + "HepG2 chr16 924.9525022506714\n", + "HepG2 chr17 931.0623359680176\n", + "HepG2 chr18 936.3413846492767\n", + "HepG2 chr19 940.5808892250061\n", + "HepG2 chr2 956.4598441123962\n", + "HepG2 chr20 961.2141344547272\n", + "HepG2 chr21 964.7532434463501\n", + "HepG2 chr22 967.2651414871216\n", + "HepG2 chr3 979.6278064250946\n", + "HepG2 chr4 992.0893275737762\n", + "HepG2 chr5 1003.5635616779327\n", + "HepG2 chr6 1013.5618009567261\n", + "HepG2 chr7 1023.0335354804993\n", + "HepG2 chr8 1031.6841459274292\n", + "HepG2 chr9 1040.966101884842\n", + "HepG2 chrX 1050.8653886318207\n", + "HepG2 1052.4600455760956\n", + "K562 chr1 1067.425773382187\n", + "K562 chr10 1077.6573045253754\n", + "K562 chr11 1090.1371166706085\n", + "K562 chr12 1097.8005549907684\n", + "K562 chr13 1104.0740525722504\n", + "K562 chr14 1110.0484874248505\n", + "K562 chr15 1115.4218137264252\n", + "K562 chr16 1120.9185173511505\n", + "K562 chr17 1125.8319385051727\n", + "K562 chr18 1130.6154806613922\n", + "K562 chr19 1134.5980072021484\n", + "K562 chr2 1149.40651845932\n", + "K562 chr20 1153.4179229736328\n", + "K562 chr21 1156.8773469924927\n", + "K562 chr22 1159.5331387519836\n", + "K562 chr3 1172.5848636627197\n", + "K562 chr4 1184.034875869751\n", + "K562 chr5 1194.7483167648315\n", + "K562 chr6 1205.1025590896606\n", + "K562 chr7 1215.1975507736206\n", + "K562 chr8 1224.603568315506\n", + "K562 chr9 1233.0638110637665\n", + "K562 chrX 1243.617464542389\n", + "K562 1245.0139937400818\n", + "A549 chr1 1261.9037923812866\n", + "A549 chr10 1269.7828676700592\n", + "A549 chr11 1277.8243072032928\n", + "A549 chr12 1285.5817420482635\n", + "A549 chr13 1292.7439670562744\n", + "A549 chr14 1299.0479788780212\n", + "A549 chr15 1305.467554807663\n", + "A549 chr16 1310.6993942260742\n", + "A549 chr17 1315.2565250396729\n", + "A549 chr18 1320.7803528308868\n", + "A549 chr19 1324.2634809017181\n", + "A549 chr2 1342.3504185676575\n", + "A549 chr20 1346.4802606105804\n", + "A549 chr21 1349.3959574699402\n", + "A549 chr22 1352.0359740257263\n", + "A549 chr3 1363.596797466278\n", + "A549 chr4 1377.4319243431091\n", + "A549 chr5 1389.1430621147156\n", + "A549 chr6 1399.9876585006714\n", + "A549 chr7 1410.3144631385803\n", + "A549 chr8 1419.1995024681091\n", + "A549 chr9 1429.0530500411987\n", + "A549 chrX 1438.9812471866608\n", + "A549 1442.8687179088593\n", + "GM12878 chr1 1460.053512096405\n", + "GM12878 chr10 1467.7110249996185\n", + "GM12878 chr11 1477.1048283576965\n", + "GM12878 chr12 1486.3174769878387\n", + "GM12878 chr13 1493.3219420909882\n", + "GM12878 chr14 1499.3263096809387\n", + "GM12878 chr15 1505.1299676895142\n", + "GM12878 chr16 1511.5330748558044\n", + "GM12878 chr17 1516.295937538147\n", + "GM12878 chr18 1521.195916891098\n", + "GM12878 chr19 1524.6872880458832\n", + "GM12878 chr2 1540.392737865448\n", + "GM12878 chr20 1544.6384494304657\n", + "GM12878 chr21 1547.7403523921967\n", + "GM12878 chr22 1550.2115426063538\n", + "GM12878 chr3 1563.652869939804\n", + "GM12878 chr4 1575.806304693222\n", + "GM12878 chr5 1587.0206112861633\n", + "GM12878 chr6 1596.6215176582336\n", + "GM12878 chr7 1606.828714132309\n", + "GM12878 chr8 1616.3425333499908\n", + "GM12878 chr9 1624.599442243576\n", + "GM12878 chrX 1634.0690217018127\n", + "GM12878 1635.526027917862\n", + "1635.5274765491486\n" ] } ], @@ -269,206 +372,30 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "all_metadata_df = pd.concat(celltype_mdta)\n", - "print(time.time() - itime)\n", - "all_metadata_df.to_csv(\n", - " _data_dir + 'labels/MAX/all_metadata_df.bed', \n", - " sep='\\t', header=False, index=False\n", - ")\n", - "print(time.time() - itime)\n", - "np.save(_data_dir + 'labels/MAX/all_metadata_y.npy', np.vstack(celltype_labels))\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, + "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "1609.5078670978546\n" + "1635.9203071594238\n", + "1644.0572729110718\n", + "1665.5485808849335\n" ] } ], "source": [ + "all_metadata_df = pd.concat(celltype_mdta)\n", + "print(time.time() - itime)\n", + "all_metadata_df.to_csv(\n", + " _data_dir + 'labels/MAX/all_metadata_df.bed', \n", + " sep='\\t', header=False, index=False\n", + ")\n", + "print(time.time() - itime)\n", "np.save(_data_dir + 'labels/MAX/all_metadata_y.npy', np.vstack(celltype_labels))\n", "print(time.time() - itime)" ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(169827, 128)" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "np.vstack(celltype_labels).shape" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
chrstartstopcelltype
0chrX06400GM12878
1chrX640012800GM12878
2chrX1280019200GM12878
3chrX1920025600GM12878
4chrX2560032000GM12878
...............
24256chrX155238400155244800GM12878
24257chrX155244800155251200GM12878
24258chrX155251200155257600GM12878
24259chrX155257600155264000GM12878
24260chrX155264000155270400GM12878
\n", - "

24261 rows × 4 columns

\n", - "
" - ], - "text/plain": [ - " chr start stop celltype\n", - "0 chrX 0 6400 GM12878\n", - "1 chrX 6400 12800 GM12878\n", - "2 chrX 12800 19200 GM12878\n", - "3 chrX 19200 25600 GM12878\n", - "4 chrX 25600 32000 GM12878\n", - "... ... ... ... ...\n", - "24256 chrX 155238400 155244800 GM12878\n", - "24257 chrX 155244800 155251200 GM12878\n", - "24258 chrX 155251200 155257600 GM12878\n", - "24259 chrX 155257600 155264000 GM12878\n", - "24260 chrX 155264000 155270400 GM12878\n", - "\n", - "[24261 rows x 4 columns]" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "celltype_mdta_df" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { From 7599caf79ca678f9d974e3e2b2b75663fcf3a1b7 Mon Sep 17 00:00:00 2001 From: aikanor Date: Mon, 22 Mar 2021 22:57:18 -0700 Subject: [PATCH 101/244] eval+metadata fixes --- .../encode-tfbs/prep_metadata_labels.ipynb | 528 ++++++++++++------ .../encode-tfbs/prep_metadata_labels.py | 46 ++ wilds/datasets/encodetfbs_dataset.py | 2 +- 3 files changed, 399 insertions(+), 177 deletions(-) diff --git a/dataset_preprocessing/encode-tfbs/prep_metadata_labels.ipynb b/dataset_preprocessing/encode-tfbs/prep_metadata_labels.ipynb index ac20f433..b4040638 100644 --- a/dataset_preprocessing/encode-tfbs/prep_metadata_labels.ipynb +++ b/dataset_preprocessing/encode-tfbs/prep_metadata_labels.ipynb @@ -105,6 +105,10 @@ ")\n", "celltype_mdta = []\n", "celltype_labels = []\n", + "_train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", + "_val_celltype = ['A549']\n", + "_test_celltype = ['GM12878']\n", + "_all_celltypes = _train_celltypes + _val_celltype + _test_celltype\n", "\n", "for ct in _all_celltypes:\n", " ct_labels_bw_path = _data_dir + \"labels/MAX/MAX_{}.bigwig\".format(ct)\n", @@ -151,182 +155,187 @@ }, { "cell_type": "code", - "execution_count": 22, - "metadata": {}, + "execution_count": 3, + "metadata": { + "collapsed": true, + "jupyter": { + "outputs_hidden": true + } + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "H1-hESC chr1 31.691891193389893\n", - "H1-hESC chr10 43.88507628440857\n", - "H1-hESC chr11 54.64318251609802\n", - "H1-hESC chr12 63.76666021347046\n", - "H1-hESC chr13 72.60888147354126\n", - "H1-hESC chr14 78.53658175468445\n", - "H1-hESC chr15 84.56542801856995\n", - "H1-hESC chr16 92.28407764434814\n", - "H1-hESC chr17 99.54330348968506\n", - "H1-hESC chr18 106.55353927612305\n", - "H1-hESC chr19 111.41691207885742\n", - "H1-hESC chr2 135.3123984336853\n", - "H1-hESC chr20 141.9123089313507\n", - "H1-hESC chr21 146.14480471611023\n", - "H1-hESC chr22 150.8621871471405\n", - "H1-hESC chr3 169.92432117462158\n", - "H1-hESC chr4 186.69121527671814\n", - "H1-hESC chr5 201.6394476890564\n", - "H1-hESC chr6 215.72684383392334\n", - "H1-hESC chr7 227.8461310863495\n", - "H1-hESC chr8 240.26825499534607\n", - "H1-hESC chr9 250.02118062973022\n", - "H1-hESC chrX 264.7451572418213\n", - "H1-hESC 267.0940718650818\n", - "HCT116 chr1 291.3232545852661\n", - "HCT116 chr10 304.0528976917267\n", - "HCT116 chr11 316.63377356529236\n", - "HCT116 chr12 329.559387922287\n", - "HCT116 chr13 341.52057003974915\n", - "HCT116 chr14 350.64817333221436\n", - "HCT116 chr15 359.71765422821045\n", - "HCT116 chr16 368.28583669662476\n", - "HCT116 chr17 376.01680874824524\n", - "HCT116 chr18 383.57749581336975\n", - "HCT116 chr19 389.313095331192\n", - "HCT116 chr2 412.09550070762634\n", - "HCT116 chr20 417.96392583847046\n", - "HCT116 chr21 422.00927662849426\n", - "HCT116 chr22 426.71226167678833\n", - "HCT116 chr3 442.76402711868286\n", - "HCT116 chr4 461.9360821247101\n", - "HCT116 chr5 478.7654387950897\n", - "HCT116 chr6 495.16735339164734\n", - "HCT116 chr7 511.82248401641846\n", - "HCT116 chr8 525.3609001636505\n", - "HCT116 chr9 538.2295203208923\n", - "HCT116 chrX 553.2177627086639\n", - "HCT116 555.6791486740112\n", - "HeLa-S3 chr1 580.0471041202545\n", - "HeLa-S3 chr10 594.5216126441956\n", - "HeLa-S3 chr11 606.1479568481445\n", - "HeLa-S3 chr12 618.3873989582062\n", - "HeLa-S3 chr13 628.7777881622314\n", - "HeLa-S3 chr14 637.718688249588\n", - "HeLa-S3 chr15 647.1737523078918\n", - "HeLa-S3 chr16 655.4854662418365\n", - "HeLa-S3 chr17 662.5001983642578\n", - "HeLa-S3 chr18 671.1846849918365\n", - "HeLa-S3 chr19 677.9798579216003\n", - "HeLa-S3 chr2 700.6258955001831\n", - "HeLa-S3 chr20 706.6806621551514\n", - "HeLa-S3 chr21 710.3620142936707\n", - "HeLa-S3 chr22 714.1444058418274\n", - "HeLa-S3 chr3 733.4964163303375\n", - "HeLa-S3 chr4 751.4331798553467\n", - "HeLa-S3 chr5 768.5986630916595\n", - "HeLa-S3 chr6 784.270875453949\n", - "HeLa-S3 chr7 797.38880443573\n", - "HeLa-S3 chr8 810.2063493728638\n", - "HeLa-S3 chr9 822.0966618061066\n", - "HeLa-S3 chrX 836.4143695831299\n", - "HeLa-S3 838.8340592384338\n", - "HepG2 chr1 859.2195272445679\n", - "HepG2 chr10 871.4736864566803\n", - "HepG2 chr11 882.3986892700195\n", - "HepG2 chr12 892.8185703754425\n", - "HepG2 chr13 902.5888631343842\n", - "HepG2 chr14 910.1683735847473\n", - "HepG2 chr15 917.9265463352203\n", - "HepG2 chr16 924.9525022506714\n", - "HepG2 chr17 931.0623359680176\n", - "HepG2 chr18 936.3413846492767\n", - "HepG2 chr19 940.5808892250061\n", - "HepG2 chr2 956.4598441123962\n", - "HepG2 chr20 961.2141344547272\n", - "HepG2 chr21 964.7532434463501\n", - "HepG2 chr22 967.2651414871216\n", - "HepG2 chr3 979.6278064250946\n", - "HepG2 chr4 992.0893275737762\n", - "HepG2 chr5 1003.5635616779327\n", - "HepG2 chr6 1013.5618009567261\n", - "HepG2 chr7 1023.0335354804993\n", - "HepG2 chr8 1031.6841459274292\n", - "HepG2 chr9 1040.966101884842\n", - "HepG2 chrX 1050.8653886318207\n", - "HepG2 1052.4600455760956\n", - "K562 chr1 1067.425773382187\n", - "K562 chr10 1077.6573045253754\n", - "K562 chr11 1090.1371166706085\n", - "K562 chr12 1097.8005549907684\n", - "K562 chr13 1104.0740525722504\n", - "K562 chr14 1110.0484874248505\n", - "K562 chr15 1115.4218137264252\n", - "K562 chr16 1120.9185173511505\n", - "K562 chr17 1125.8319385051727\n", - "K562 chr18 1130.6154806613922\n", - "K562 chr19 1134.5980072021484\n", - "K562 chr2 1149.40651845932\n", - "K562 chr20 1153.4179229736328\n", - "K562 chr21 1156.8773469924927\n", - "K562 chr22 1159.5331387519836\n", - "K562 chr3 1172.5848636627197\n", - "K562 chr4 1184.034875869751\n", - "K562 chr5 1194.7483167648315\n", - "K562 chr6 1205.1025590896606\n", - "K562 chr7 1215.1975507736206\n", - "K562 chr8 1224.603568315506\n", - "K562 chr9 1233.0638110637665\n", - "K562 chrX 1243.617464542389\n", - "K562 1245.0139937400818\n", - "A549 chr1 1261.9037923812866\n", - "A549 chr10 1269.7828676700592\n", - "A549 chr11 1277.8243072032928\n", - "A549 chr12 1285.5817420482635\n", - "A549 chr13 1292.7439670562744\n", - "A549 chr14 1299.0479788780212\n", - "A549 chr15 1305.467554807663\n", - "A549 chr16 1310.6993942260742\n", - "A549 chr17 1315.2565250396729\n", - "A549 chr18 1320.7803528308868\n", - "A549 chr19 1324.2634809017181\n", - "A549 chr2 1342.3504185676575\n", - "A549 chr20 1346.4802606105804\n", - "A549 chr21 1349.3959574699402\n", - "A549 chr22 1352.0359740257263\n", - "A549 chr3 1363.596797466278\n", - "A549 chr4 1377.4319243431091\n", - "A549 chr5 1389.1430621147156\n", - "A549 chr6 1399.9876585006714\n", - "A549 chr7 1410.3144631385803\n", - "A549 chr8 1419.1995024681091\n", - "A549 chr9 1429.0530500411987\n", - "A549 chrX 1438.9812471866608\n", - "A549 1442.8687179088593\n", - "GM12878 chr1 1460.053512096405\n", - "GM12878 chr10 1467.7110249996185\n", - "GM12878 chr11 1477.1048283576965\n", - "GM12878 chr12 1486.3174769878387\n", - "GM12878 chr13 1493.3219420909882\n", - "GM12878 chr14 1499.3263096809387\n", - "GM12878 chr15 1505.1299676895142\n", - "GM12878 chr16 1511.5330748558044\n", - "GM12878 chr17 1516.295937538147\n", - "GM12878 chr18 1521.195916891098\n", - "GM12878 chr19 1524.6872880458832\n", - "GM12878 chr2 1540.392737865448\n", - "GM12878 chr20 1544.6384494304657\n", - "GM12878 chr21 1547.7403523921967\n", - "GM12878 chr22 1550.2115426063538\n", - "GM12878 chr3 1563.652869939804\n", - "GM12878 chr4 1575.806304693222\n", - "GM12878 chr5 1587.0206112861633\n", - "GM12878 chr6 1596.6215176582336\n", - "GM12878 chr7 1606.828714132309\n", - "GM12878 chr8 1616.3425333499908\n", - "GM12878 chr9 1624.599442243576\n", - "GM12878 chrX 1634.0690217018127\n", - "GM12878 1635.526027917862\n", - "1635.5274765491486\n" + "H1-hESC chr1 20.20913076400757\n", + "H1-hESC chr10 28.008360862731934\n", + "H1-hESC chr11 36.41605806350708\n", + "H1-hESC chr12 44.10679650306702\n", + "H1-hESC chr13 51.19197916984558\n", + "H1-hESC chr14 57.02869009971619\n", + "H1-hESC chr15 62.31191349029541\n", + "H1-hESC chr16 67.81044888496399\n", + "H1-hESC chr17 72.55425524711609\n", + "H1-hESC chr18 78.06788182258606\n", + "H1-hESC chr19 81.6804301738739\n", + "H1-hESC chr2 96.18858242034912\n", + "H1-hESC chr20 100.57028126716614\n", + "H1-hESC chr21 103.48856258392334\n", + "H1-hESC chr22 106.65493178367615\n", + "H1-hESC chr3 119.19174075126648\n", + "H1-hESC chr4 131.20506811141968\n", + "H1-hESC chr5 142.5725815296173\n", + "H1-hESC chr6 152.76653385162354\n", + "H1-hESC chr7 162.29314422607422\n", + "H1-hESC chr8 172.217839717865\n", + "H1-hESC chr9 180.51852083206177\n", + "H1-hESC chrX 189.65529799461365\n", + "H1-hESC 190.9764485359192\n", + "HCT116 chr1 206.9115183353424\n", + "HCT116 chr10 214.90280389785767\n", + "HCT116 chr11 223.43896079063416\n", + "HCT116 chr12 231.92131686210632\n", + "HCT116 chr13 238.30261087417603\n", + "HCT116 chr14 244.51456451416016\n", + "HCT116 chr15 250.19079542160034\n", + "HCT116 chr16 255.56156754493713\n", + "HCT116 chr17 260.19018745422363\n", + "HCT116 chr18 264.7617914676666\n", + "HCT116 chr19 268.18313336372375\n", + "HCT116 chr2 282.2316060066223\n", + "HCT116 chr20 286.1222012042999\n", + "HCT116 chr21 288.44735455513\n", + "HCT116 chr22 290.89445447921753\n", + "HCT116 chr3 303.90423917770386\n", + "HCT116 chr4 315.55219316482544\n", + "HCT116 chr5 325.40247106552124\n", + "HCT116 chr6 335.81401777267456\n", + "HCT116 chr7 344.9978699684143\n", + "HCT116 chr8 353.3862988948822\n", + "HCT116 chr9 361.44275426864624\n", + "HCT116 chrX 370.7123851776123\n", + "HCT116 372.3075313568115\n", + "HeLa-S3 chr1 387.19204020500183\n", + "HeLa-S3 chr10 395.3121614456177\n", + "HeLa-S3 chr11 403.74219489097595\n", + "HeLa-S3 chr12 411.4250144958496\n", + "HeLa-S3 chr13 418.4364001750946\n", + "HeLa-S3 chr14 423.98758840560913\n", + "HeLa-S3 chr15 429.3569166660309\n", + "HeLa-S3 chr16 434.58498072624207\n", + "HeLa-S3 chr17 439.32522535324097\n", + "HeLa-S3 chr18 443.5133364200592\n", + "HeLa-S3 chr19 448.0177550315857\n", + "HeLa-S3 chr2 462.9882276058197\n", + "HeLa-S3 chr20 466.51298093795776\n", + "HeLa-S3 chr21 469.0944080352783\n", + "HeLa-S3 chr22 471.5260305404663\n", + "HeLa-S3 chr3 483.85995268821716\n", + "HeLa-S3 chr4 495.36716389656067\n", + "HeLa-S3 chr5 507.5515208244324\n", + "HeLa-S3 chr6 517.3091206550598\n", + "HeLa-S3 chr7 526.7820916175842\n", + "HeLa-S3 chr8 536.0832920074463\n", + "HeLa-S3 chr9 544.3730075359344\n", + "HeLa-S3 chrX 554.4541621208191\n", + "HeLa-S3 555.8484582901001\n", + "HepG2 chr1 570.9822986125946\n", + "HepG2 chr10 578.9540090560913\n", + "HepG2 chr11 587.8593363761902\n", + "HepG2 chr12 595.8393228054047\n", + "HepG2 chr13 602.6833045482635\n", + "HepG2 chr14 609.1072862148285\n", + "HepG2 chr15 614.747784614563\n", + "HepG2 chr16 620.7166090011597\n", + "HepG2 chr17 625.3372988700867\n", + "HepG2 chr18 629.8560705184937\n", + "HepG2 chr19 633.2941951751709\n", + "HepG2 chr2 648.9851665496826\n", + "HepG2 chr20 652.6327149868011\n", + "HepG2 chr21 655.2463212013245\n", + "HepG2 chr22 658.0303378105164\n", + "HepG2 chr3 669.5658092498779\n", + "HepG2 chr4 682.1893765926361\n", + "HepG2 chr5 692.735545873642\n", + "HepG2 chr6 703.3878519535065\n", + "HepG2 chr7 713.1193616390228\n", + "HepG2 chr8 722.3643298149109\n", + "HepG2 chr9 730.6321549415588\n", + "HepG2 chrX 740.6085741519928\n", + "HepG2 742.0141065120697\n", + "K562 chr1 758.5779891014099\n", + "K562 chr10 767.4113881587982\n", + "K562 chr11 775.8454809188843\n", + "K562 chr12 783.9884355068207\n", + "K562 chr13 791.1310849189758\n", + "K562 chr14 797.672040939331\n", + "K562 chr15 803.6226804256439\n", + "K562 chr16 808.6951246261597\n", + "K562 chr17 813.4657413959503\n", + "K562 chr18 817.7939398288727\n", + "K562 chr19 821.4753954410553\n", + "K562 chr2 836.5985827445984\n", + "K562 chr20 840.5005617141724\n", + "K562 chr21 843.5687084197998\n", + "K562 chr22 847.0207371711731\n", + "K562 chr3 859.3639187812805\n", + "K562 chr4 872.1509864330292\n", + "K562 chr5 883.8973982334137\n", + "K562 chr6 894.884886264801\n", + "K562 chr7 904.8027040958405\n", + "K562 chr8 913.5210344791412\n", + "K562 chr9 923.3883426189423\n", + "K562 chrX 932.1968190670013\n", + "K562 933.5820450782776\n", + "A549 chr1 949.486894607544\n", + "A549 chr10 958.3895864486694\n", + "A549 chr11 967.0816054344177\n", + "A549 chr12 975.5747411251068\n", + "A549 chr13 982.0127573013306\n", + "A549 chr14 987.795184135437\n", + "A549 chr15 993.7581944465637\n", + "A549 chr16 999.4659056663513\n", + "A549 chr17 1004.0420877933502\n", + "A549 chr18 1009.2555639743805\n", + "A549 chr19 1012.6145355701447\n", + "A549 chr2 1026.8949675559998\n", + "A549 chr20 1030.6796989440918\n", + "A549 chr21 1033.2921645641327\n", + "A549 chr22 1035.7836105823517\n", + "A549 chr3 1047.997889995575\n", + "A549 chr4 1060.1437220573425\n", + "A549 chr5 1073.8904914855957\n", + "A549 chr6 1084.380021572113\n", + "A549 chr7 1093.8411493301392\n", + "A549 chr8 1103.4142487049103\n", + "A549 chr9 1111.7713177204132\n", + "A549 chrX 1121.4708564281464\n", + "A549 1123.1007504463196\n", + "GM12878 chr1 1139.6929399967194\n", + "GM12878 chr10 1149.3214106559753\n", + "GM12878 chr11 1156.9325966835022\n", + "GM12878 chr12 1165.1995975971222\n", + "GM12878 chr13 1171.642731666565\n", + "GM12878 chr14 1177.269201040268\n", + "GM12878 chr15 1182.3881227970123\n", + "GM12878 chr16 1189.4382717609406\n", + "GM12878 chr17 1195.1939060688019\n", + "GM12878 chr18 1199.5745940208435\n", + "GM12878 chr19 1202.9597895145416\n", + "GM12878 chr2 1218.1304664611816\n", + "GM12878 chr20 1222.2568864822388\n", + "GM12878 chr21 1225.2206535339355\n", + "GM12878 chr22 1228.303787469864\n", + "GM12878 chr3 1240.6525540351868\n", + "GM12878 chr4 1251.8951542377472\n", + "GM12878 chr5 1263.021770954132\n", + "GM12878 chr6 1273.153335094452\n", + "GM12878 chr7 1283.4220962524414\n", + "GM12878 chr8 1293.7367491722107\n", + "GM12878 chr9 1301.8155844211578\n", + "GM12878 chrX 1310.8613522052765\n", + "GM12878 1312.4627130031586\n", + "1312.4639494419098\n" ] } ], @@ -352,7 +361,7 @@ " for chrID in bw.chroms():\n", " chromsize = bw.chroms()[chrID]\n", " # Iterate over windows\n", - " for startc in np.arange(0, chromsize, stride):\n", + " for startc in np.arange(0, chromsize-(2*stride), stride):\n", " u_end = startc + stride\n", " if u_end > chromsize:\n", " break\n", @@ -372,16 +381,35 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "1635.9203071594238\n", - "1644.0572729110718\n", - "1665.5485808849335\n" + "{'chr1': 249250621, 'chr10': 135534747, 'chr11': 135006516, 'chr12': 133851895, 'chr13': 115169878, 'chr14': 107349540, 'chr15': 102531392, 'chr16': 90354753, 'chr17': 81195210, 'chr18': 78077248, 'chr19': 59128983, 'chr2': 243199373, 'chr20': 63025520, 'chr21': 48129895, 'chr22': 51304566, 'chr3': 198022430, 'chr4': 191154276, 'chr5': 180915260, 'chr6': 171115067, 'chr7': 159138663, 'chr8': 146364022, 'chr9': 141213431, 'chrX': 155270560}\n" + ] + } + ], + "source": [ + "bw = pyBigWig.open(ct_labels_bw_path)\n", + "print(bw.chroms())\n", + "bw.close()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4135.834958076477\n", + "4144.659387350082\n", + "4163.944329023361\n" ] } ], @@ -396,6 +424,154 @@ "np.save(_data_dir + 'labels/MAX/all_metadata_y.npy', np.vstack(celltype_labels))\n", "print(time.time() - itime)" ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
chrstartstopcelltype
0chr106400H1-hESC
1chr1640012800H1-hESC
2chr11280019200H1-hESC
3chr11920025600H1-hESC
4chr12560032000H1-hESC
...............
474383chrX155232000155238400GM12878
474384chrX155238400155244800GM12878
474385chrX155244800155251200GM12878
474386chrX155251200155257600GM12878
474387chrX155257600155264000GM12878
\n", + "

3320716 rows × 4 columns

\n", + "
" + ], + "text/plain": [ + " chr start stop celltype\n", + "0 chr1 0 6400 H1-hESC\n", + "1 chr1 6400 12800 H1-hESC\n", + "2 chr1 12800 19200 H1-hESC\n", + "3 chr1 19200 25600 H1-hESC\n", + "4 chr1 25600 32000 H1-hESC\n", + "... ... ... ... ...\n", + "474383 chrX 155232000 155238400 GM12878\n", + "474384 chrX 155238400 155244800 GM12878\n", + "474385 chrX 155244800 155251200 GM12878\n", + "474386 chrX 155251200 155257600 GM12878\n", + "474387 chrX 155257600 155264000 GM12878\n", + "\n", + "[3320716 rows x 4 columns]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "all_metadata_df" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py b/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py index de156920..12352e53 100644 --- a/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py +++ b/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py @@ -76,6 +76,52 @@ def write_label_bigwigs(): bw.close() +def write_(): + stride = 6400 + itime = time.time() + mdf_posamb = pd.read_csv( + _sorted_dir, + sep='\t', header=None, index_col=None, names=['chr', 'start', 'stop', 'y', 'celltype'] + ) + celltype_mdta = [] + celltype_labels = [] + + for ct in _all_celltypes: + ct_labels_bw_path = _data_dir + "labels/MAX/MAX_{}.bigwig".format(ct) + df = mdf_posamb[mdf_posamb['celltype'] == ct] + df['window_start'] = stride*(df['start'] // stride) + uniq_windows = np.unique(["{}:{}".format(x[0], x[1]) for x in zip(df['chr'], df['window_start'])]) + df_construction = [] + mdta_labels = [] + + bw = pyBigWig.open(ct_labels_bw_path) + num_reps = 0 + for u in uniq_windows: + u_chr = u.split(':')[0] + u_start = int(u.split(':')[1]) + u_end = u_start + stride + x = np.nan_to_num(bw.values(u_chr, u_start, u_end, numpy=True)) + df_construction.append((u_chr, u_start, u_end)) + mdta_labels.append(x[np.arange(0, len(x), 50)]) + num_reps = num_reps + 1 + celltype_mdta_df = pd.DataFrame(df_construction, columns=['chr', 'start', 'stop']) + celltype_mdta_df.insert(len(celltype_mdta_df.columns), 'celltype', ct) + celltype_mdta.append(celltype_mdta_df) + celltype_labels.append(np.stack(mdta_labels)) + print(ct, time.time() - itime) + bw.close() + # break + print(time.time() - itime) + # _metadata_df + + pd.concat(celltype_mdta).to_csv( + _data_dir + 'labels/MAX/metadata_df.bed', + sep='\t', header=False, index=False + ) + np.save(_data_dir + 'labels/MAX/metadata_y.npy', np.vstack(celltype_labels)) + print(time.time() - itime) + + if __name__ == '__main__': write_label_bigwigs() \ No newline at end of file diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 90de3245..3723086d 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -186,7 +186,7 @@ def get_input(self, idx, window_size=12800): seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end] dnase_bw = self._dnase_allcelltypes[this_metadata['celltype']] dnase_this = dnase_bw.values(chrom, interval_start, interval_end, numpy=True) - + assert(np.isnan(seq_this).sum() == 0) assert(np.isnan(dnase_this).sum() == 0) return torch.tensor(np.column_stack( From 5b84bd0e40101d30fc9afc80d78326d61bce0701 Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Mon, 22 Mar 2021 23:35:34 -0700 Subject: [PATCH 102/244] update predictions path --- examples/evaluate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/evaluate.py b/examples/evaluate.py index 7d957f41..a5e7a7e3 100644 --- a/examples/evaluate.py +++ b/examples/evaluate.py @@ -39,7 +39,7 @@ def evaluate_all_benchmarks(predictions_dir: str, output_dir: str, root_dir: str for dataset in benchmark_datasets: try: all_results[dataset] = evaluate_benchmark( - dataset, predictions_dir, output_dir, root_dir + dataset, os.path.join(predictions_dir, dataset), output_dir, root_dir ) except Exception as e: print(f"Could not evaluate predictions for {dataset}:\n{str(e)}") From 88edaafc68ef011b5303791ae65487eb3c48093e Mon Sep 17 00:00:00 2001 From: aikanor Date: Tue, 23 Mar 2021 14:59:00 -0700 Subject: [PATCH 103/244] Modified bundle, metadata preprocessing to include all regions --- .../encode-tfbs/prep_metadata_labels.ipynb | 598 ------------------ .../encode-tfbs/prep_metadata_labels.py | 97 +-- wilds/datasets/encodetfbs_dataset.py | 7 +- 3 files changed, 57 insertions(+), 645 deletions(-) delete mode 100644 dataset_preprocessing/encode-tfbs/prep_metadata_labels.ipynb diff --git a/dataset_preprocessing/encode-tfbs/prep_metadata_labels.ipynb b/dataset_preprocessing/encode-tfbs/prep_metadata_labels.ipynb deleted file mode 100644 index b4040638..00000000 --- a/dataset_preprocessing/encode-tfbs/prep_metadata_labels.ipynb +++ /dev/null @@ -1,598 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import os, csv\n", - "import scipy, numpy as np, pandas as pd, time\n", - "from scipy import sparse\n", - "import pyBigWig, prep_metadata_labels\n", - "\n", - "# Human chromosome names\n", - "chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']\n", - "_data_dir = '../../examples/data/encode-tfbs_v1.0/'" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Prep metadata df and metadata/label array\n", - "- Metadata df contains 6400bp (window_size/2) prediction windows across the genome. Each gets a 128-bit prediction from the model.\n", - "- We store the ones that aren't fully unbound, and write these to bigwigs representing genome-wide labels." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "83.30138063430786\n", - "H1-hESC 100.73247504234314\n", - "HCT116 106.4023334980011\n", - "HeLa-S3 111.88021206855774\n", - "HepG2 117.56940197944641\n", - "K562 126.93423342704773\n", - "A549 138.21517205238342\n", - "GM12878 148.77391648292542\n", - "150.62964010238647\n", - "213.72714066505432\n" - ] - } - ], - "source": [ - "prep_metadata_labels.write_label_bigwigs()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "- Then read from the bigwigs to generate metadata for the bound sites." - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - ":9: SettingWithCopyWarning: \n", - "A value is trying to be set on a copy of a slice from a DataFrame.\n", - "Try using .loc[row_indexer,col_indexer] = value instead\n", - "\n", - "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", - " df['window_start'] = stride*(df['start'] // stride)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "A549 63.97912120819092\n", - "GM12878 103.89278292655945\n", - "H1-hESC 182.84059262275696\n", - "HCT116 243.95744681358337\n", - "HeLa-S3 303.7187397480011\n", - "HepG2 375.8099205493927\n", - "K562 456.08897161483765\n", - "456.0923991203308\n", - "462.8749210834503\n" - ] - } - ], - "source": [ - "stride = 6400\n", - "itime = time.time()\n", - "mdf_posamb = pd.read_csv(\n", - " _sorted_dir, \n", - " sep='\\t', header=None, index_col=None, names=['chr', 'start', 'stop', 'y', 'celltype']\n", - ")\n", - "celltype_mdta = []\n", - "celltype_labels = []\n", - "_train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", - "_val_celltype = ['A549']\n", - "_test_celltype = ['GM12878']\n", - "_all_celltypes = _train_celltypes + _val_celltype + _test_celltype\n", - "\n", - "for ct in _all_celltypes:\n", - " ct_labels_bw_path = _data_dir + \"labels/MAX/MAX_{}.bigwig\".format(ct)\n", - " df = mdf_posamb[mdf_posamb['celltype'] == ct]\n", - " df['window_start'] = stride*(df['start'] // stride)\n", - " uniq_windows = np.unique([\"{}:{}\".format(x[0], x[1]) for x in zip(df['chr'], df['window_start'])])\n", - " df_construction = []\n", - " mdta_labels = []\n", - " \n", - " bw = pyBigWig.open(ct_labels_bw_path)\n", - " num_reps = 0\n", - " for u in uniq_windows:\n", - " u_chr = u.split(':')[0]\n", - " u_start = int(u.split(':')[1])\n", - " u_end = u_start + stride\n", - " x = np.nan_to_num(bw.values(u_chr, u_start, u_end, numpy=True))\n", - " df_construction.append((u_chr, u_start, u_end))\n", - " mdta_labels.append(x[np.arange(0, len(x), 50)])\n", - " num_reps = num_reps + 1\n", - " celltype_mdta_df = pd.DataFrame(df_construction, columns=['chr', 'start', 'stop'])\n", - " celltype_mdta_df.insert(len(celltype_mdta_df.columns), 'celltype', ct)\n", - " celltype_mdta.append(celltype_mdta_df)\n", - " celltype_labels.append(np.stack(mdta_labels))\n", - " print(ct, time.time() - itime)\n", - " bw.close()\n", - " # break\n", - "print(time.time() - itime)\n", - "# _metadata_df\n", - "\n", - "pd.concat(celltype_mdta).to_csv(\n", - " _data_dir + 'labels/MAX/metadata_df.bed', \n", - " sep='\\t', header=False, index=False\n", - ")\n", - "np.save(_data_dir + 'labels/MAX/metadata_y.npy', np.vstack(celltype_labels))\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Add the all-unbound sites" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "collapsed": true, - "jupyter": { - "outputs_hidden": true - } - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "H1-hESC chr1 20.20913076400757\n", - "H1-hESC chr10 28.008360862731934\n", - "H1-hESC chr11 36.41605806350708\n", - "H1-hESC chr12 44.10679650306702\n", - "H1-hESC chr13 51.19197916984558\n", - "H1-hESC chr14 57.02869009971619\n", - "H1-hESC chr15 62.31191349029541\n", - "H1-hESC chr16 67.81044888496399\n", - "H1-hESC chr17 72.55425524711609\n", - "H1-hESC chr18 78.06788182258606\n", - "H1-hESC chr19 81.6804301738739\n", - "H1-hESC chr2 96.18858242034912\n", - "H1-hESC chr20 100.57028126716614\n", - "H1-hESC chr21 103.48856258392334\n", - "H1-hESC chr22 106.65493178367615\n", - "H1-hESC chr3 119.19174075126648\n", - "H1-hESC chr4 131.20506811141968\n", - "H1-hESC chr5 142.5725815296173\n", - "H1-hESC chr6 152.76653385162354\n", - "H1-hESC chr7 162.29314422607422\n", - "H1-hESC chr8 172.217839717865\n", - "H1-hESC chr9 180.51852083206177\n", - "H1-hESC chrX 189.65529799461365\n", - "H1-hESC 190.9764485359192\n", - "HCT116 chr1 206.9115183353424\n", - "HCT116 chr10 214.90280389785767\n", - "HCT116 chr11 223.43896079063416\n", - "HCT116 chr12 231.92131686210632\n", - "HCT116 chr13 238.30261087417603\n", - "HCT116 chr14 244.51456451416016\n", - "HCT116 chr15 250.19079542160034\n", - "HCT116 chr16 255.56156754493713\n", - "HCT116 chr17 260.19018745422363\n", - "HCT116 chr18 264.7617914676666\n", - "HCT116 chr19 268.18313336372375\n", - "HCT116 chr2 282.2316060066223\n", - "HCT116 chr20 286.1222012042999\n", - "HCT116 chr21 288.44735455513\n", - "HCT116 chr22 290.89445447921753\n", - "HCT116 chr3 303.90423917770386\n", - "HCT116 chr4 315.55219316482544\n", - "HCT116 chr5 325.40247106552124\n", - "HCT116 chr6 335.81401777267456\n", - "HCT116 chr7 344.9978699684143\n", - "HCT116 chr8 353.3862988948822\n", - "HCT116 chr9 361.44275426864624\n", - "HCT116 chrX 370.7123851776123\n", - "HCT116 372.3075313568115\n", - "HeLa-S3 chr1 387.19204020500183\n", - "HeLa-S3 chr10 395.3121614456177\n", - "HeLa-S3 chr11 403.74219489097595\n", - "HeLa-S3 chr12 411.4250144958496\n", - "HeLa-S3 chr13 418.4364001750946\n", - "HeLa-S3 chr14 423.98758840560913\n", - "HeLa-S3 chr15 429.3569166660309\n", - "HeLa-S3 chr16 434.58498072624207\n", - "HeLa-S3 chr17 439.32522535324097\n", - "HeLa-S3 chr18 443.5133364200592\n", - "HeLa-S3 chr19 448.0177550315857\n", - "HeLa-S3 chr2 462.9882276058197\n", - "HeLa-S3 chr20 466.51298093795776\n", - "HeLa-S3 chr21 469.0944080352783\n", - "HeLa-S3 chr22 471.5260305404663\n", - "HeLa-S3 chr3 483.85995268821716\n", - "HeLa-S3 chr4 495.36716389656067\n", - "HeLa-S3 chr5 507.5515208244324\n", - "HeLa-S3 chr6 517.3091206550598\n", - "HeLa-S3 chr7 526.7820916175842\n", - "HeLa-S3 chr8 536.0832920074463\n", - "HeLa-S3 chr9 544.3730075359344\n", - "HeLa-S3 chrX 554.4541621208191\n", - "HeLa-S3 555.8484582901001\n", - "HepG2 chr1 570.9822986125946\n", - "HepG2 chr10 578.9540090560913\n", - "HepG2 chr11 587.8593363761902\n", - "HepG2 chr12 595.8393228054047\n", - "HepG2 chr13 602.6833045482635\n", - "HepG2 chr14 609.1072862148285\n", - "HepG2 chr15 614.747784614563\n", - "HepG2 chr16 620.7166090011597\n", - "HepG2 chr17 625.3372988700867\n", - "HepG2 chr18 629.8560705184937\n", - "HepG2 chr19 633.2941951751709\n", - "HepG2 chr2 648.9851665496826\n", - "HepG2 chr20 652.6327149868011\n", - "HepG2 chr21 655.2463212013245\n", - "HepG2 chr22 658.0303378105164\n", - "HepG2 chr3 669.5658092498779\n", - "HepG2 chr4 682.1893765926361\n", - "HepG2 chr5 692.735545873642\n", - "HepG2 chr6 703.3878519535065\n", - "HepG2 chr7 713.1193616390228\n", - "HepG2 chr8 722.3643298149109\n", - "HepG2 chr9 730.6321549415588\n", - "HepG2 chrX 740.6085741519928\n", - "HepG2 742.0141065120697\n", - "K562 chr1 758.5779891014099\n", - "K562 chr10 767.4113881587982\n", - "K562 chr11 775.8454809188843\n", - "K562 chr12 783.9884355068207\n", - "K562 chr13 791.1310849189758\n", - "K562 chr14 797.672040939331\n", - "K562 chr15 803.6226804256439\n", - "K562 chr16 808.6951246261597\n", - "K562 chr17 813.4657413959503\n", - "K562 chr18 817.7939398288727\n", - "K562 chr19 821.4753954410553\n", - "K562 chr2 836.5985827445984\n", - "K562 chr20 840.5005617141724\n", - "K562 chr21 843.5687084197998\n", - "K562 chr22 847.0207371711731\n", - "K562 chr3 859.3639187812805\n", - "K562 chr4 872.1509864330292\n", - "K562 chr5 883.8973982334137\n", - "K562 chr6 894.884886264801\n", - "K562 chr7 904.8027040958405\n", - "K562 chr8 913.5210344791412\n", - "K562 chr9 923.3883426189423\n", - "K562 chrX 932.1968190670013\n", - "K562 933.5820450782776\n", - "A549 chr1 949.486894607544\n", - "A549 chr10 958.3895864486694\n", - "A549 chr11 967.0816054344177\n", - "A549 chr12 975.5747411251068\n", - "A549 chr13 982.0127573013306\n", - "A549 chr14 987.795184135437\n", - "A549 chr15 993.7581944465637\n", - "A549 chr16 999.4659056663513\n", - "A549 chr17 1004.0420877933502\n", - "A549 chr18 1009.2555639743805\n", - "A549 chr19 1012.6145355701447\n", - "A549 chr2 1026.8949675559998\n", - "A549 chr20 1030.6796989440918\n", - "A549 chr21 1033.2921645641327\n", - "A549 chr22 1035.7836105823517\n", - "A549 chr3 1047.997889995575\n", - "A549 chr4 1060.1437220573425\n", - "A549 chr5 1073.8904914855957\n", - "A549 chr6 1084.380021572113\n", - "A549 chr7 1093.8411493301392\n", - "A549 chr8 1103.4142487049103\n", - "A549 chr9 1111.7713177204132\n", - "A549 chrX 1121.4708564281464\n", - "A549 1123.1007504463196\n", - "GM12878 chr1 1139.6929399967194\n", - "GM12878 chr10 1149.3214106559753\n", - "GM12878 chr11 1156.9325966835022\n", - "GM12878 chr12 1165.1995975971222\n", - "GM12878 chr13 1171.642731666565\n", - "GM12878 chr14 1177.269201040268\n", - "GM12878 chr15 1182.3881227970123\n", - "GM12878 chr16 1189.4382717609406\n", - "GM12878 chr17 1195.1939060688019\n", - "GM12878 chr18 1199.5745940208435\n", - "GM12878 chr19 1202.9597895145416\n", - "GM12878 chr2 1218.1304664611816\n", - "GM12878 chr20 1222.2568864822388\n", - "GM12878 chr21 1225.2206535339355\n", - "GM12878 chr22 1228.303787469864\n", - "GM12878 chr3 1240.6525540351868\n", - "GM12878 chr4 1251.8951542377472\n", - "GM12878 chr5 1263.021770954132\n", - "GM12878 chr6 1273.153335094452\n", - "GM12878 chr7 1283.4220962524414\n", - "GM12878 chr8 1293.7367491722107\n", - "GM12878 chr9 1301.8155844211578\n", - "GM12878 chrX 1310.8613522052765\n", - "GM12878 1312.4627130031586\n", - "1312.4639494419098\n" - ] - } - ], - "source": [ - "stride = 6400\n", - "itime = time.time()\n", - "mdf_posamb = pd.read_csv(\n", - " _data_dir + 'labels/MAX/MAX_posamb.sorted.bed', \n", - " sep='\\t', header=None, index_col=None, names=['chr', 'start', 'stop', 'y', 'celltype']\n", - ")\n", - "celltype_mdta = []\n", - "celltype_labels = []\n", - "_train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", - "_val_celltype = ['A549']\n", - "_test_celltype = ['GM12878']\n", - "_all_celltypes = _train_celltypes + _val_celltype + _test_celltype\n", - "\n", - "for ct in _all_celltypes:\n", - " ct_labels_bw_path = _data_dir + \"labels/MAX/MAX_{}.bigwig\".format(ct)\n", - " df_construction = []\n", - " mdta_labels = []\n", - " bw = pyBigWig.open(ct_labels_bw_path)\n", - " for chrID in bw.chroms():\n", - " chromsize = bw.chroms()[chrID]\n", - " # Iterate over windows\n", - " for startc in np.arange(0, chromsize-(2*stride), stride):\n", - " u_end = startc + stride\n", - " if u_end > chromsize:\n", - " break\n", - " x = np.nan_to_num(bw.values(chrID, startc, u_end, numpy=True))\n", - " df_construction.append((chrID, startc, u_end))\n", - " mdta_labels.append(x[np.arange(0, len(x), 50)])\n", - " print(ct, chrID, time.time() - itime)\n", - " celltype_mdta_df = pd.DataFrame(df_construction, columns=['chr', 'start', 'stop'])\n", - " celltype_mdta_df.insert(len(celltype_mdta_df.columns), 'celltype', ct)\n", - " celltype_mdta.append(celltype_mdta_df)\n", - " celltype_labels.append(np.stack(mdta_labels))\n", - " print(ct, time.time() - itime)\n", - " bw.close()\n", - " # break\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'chr1': 249250621, 'chr10': 135534747, 'chr11': 135006516, 'chr12': 133851895, 'chr13': 115169878, 'chr14': 107349540, 'chr15': 102531392, 'chr16': 90354753, 'chr17': 81195210, 'chr18': 78077248, 'chr19': 59128983, 'chr2': 243199373, 'chr20': 63025520, 'chr21': 48129895, 'chr22': 51304566, 'chr3': 198022430, 'chr4': 191154276, 'chr5': 180915260, 'chr6': 171115067, 'chr7': 159138663, 'chr8': 146364022, 'chr9': 141213431, 'chrX': 155270560}\n" - ] - } - ], - "source": [ - "bw = pyBigWig.open(ct_labels_bw_path)\n", - "print(bw.chroms())\n", - "bw.close()" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "4135.834958076477\n", - "4144.659387350082\n", - "4163.944329023361\n" - ] - } - ], - "source": [ - "all_metadata_df = pd.concat(celltype_mdta)\n", - "print(time.time() - itime)\n", - "all_metadata_df.to_csv(\n", - " _data_dir + 'labels/MAX/all_metadata_df.bed', \n", - " sep='\\t', header=False, index=False\n", - ")\n", - "print(time.time() - itime)\n", - "np.save(_data_dir + 'labels/MAX/all_metadata_y.npy', np.vstack(celltype_labels))\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
chrstartstopcelltype
0chr106400H1-hESC
1chr1640012800H1-hESC
2chr11280019200H1-hESC
3chr11920025600H1-hESC
4chr12560032000H1-hESC
...............
474383chrX155232000155238400GM12878
474384chrX155238400155244800GM12878
474385chrX155244800155251200GM12878
474386chrX155251200155257600GM12878
474387chrX155257600155264000GM12878
\n", - "

3320716 rows × 4 columns

\n", - "
" - ], - "text/plain": [ - " chr start stop celltype\n", - "0 chr1 0 6400 H1-hESC\n", - "1 chr1 6400 12800 H1-hESC\n", - "2 chr1 12800 19200 H1-hESC\n", - "3 chr1 19200 25600 H1-hESC\n", - "4 chr1 25600 32000 H1-hESC\n", - "... ... ... ... ...\n", - "474383 chrX 155232000 155238400 GM12878\n", - "474384 chrX 155238400 155244800 GM12878\n", - "474385 chrX 155244800 155251200 GM12878\n", - "474386 chrX 155251200 155257600 GM12878\n", - "474387 chrX 155257600 155264000 GM12878\n", - "\n", - "[3320716 rows x 4 columns]" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "all_metadata_df" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.5" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py b/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py index 12352e53..ca8c142f 100644 --- a/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py +++ b/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py @@ -10,30 +10,24 @@ _data_dir = '../../examples/data/encode-tfbs_v1.0/' -def write_label_bigwigs(): +def write_label_bigwigs(celltypes): itime = time.time() - transcription_factor = 'MAX' + tf_name = 'MAX' _train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] _val_chroms = ['chr2', 'chr9', 'chr11'] _test_chroms = ['chr1', 'chr8', 'chr21'] _all_chroms = _train_chroms + _val_chroms + _test_chroms - _train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] - _val_celltype = ['A549'] - _test_celltype = ['GM12878'] - _all_celltypes = _train_celltypes + _val_celltype + _test_celltype # Read in metadata dataframe from training+validation data - train_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.train.labels.tsv.gz'.format(_transcription_factor)), sep='\t') - val_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.val.labels.tsv.gz'.format(_transcription_factor)), sep='\t') + train_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.train.labels.tsv.gz'.format(tf_name)), sep='\t') + val_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.val.labels.tsv.gz'.format(tf_name)), sep='\t') training_df = train_regions_labeled# [np.isin(train_regions_labeled['chr'], _train_chroms)] val_df = val_regions_labeled# [np.isin(val_regions_labeled['chr'], _test_chroms)] all_df = pd.concat([training_df, val_df]) - print(time.time() - itime) - # Get the y values, and remove labels by default. pd_list = [] - for ct in _all_celltypes: + for ct in celltypes: tc_chr = all_df[['chr', 'start', 'stop', ct]] tc_chr.columns = ['chr', 'start', 'stop', 'y'] tc_chr = tc_chr[tc_chr['y'] != 'U'] @@ -46,10 +40,10 @@ def write_label_bigwigs(): print(time.time() - itime) _unsorted_dir = _data_dir + 'labels/{}/{}_posamb.bed'.format( - transcription_factor, transcription_factor) + tf_name, tf_name) _sorted_dir = _unsorted_dir.replace( - '{}_posamb'.format(transcription_factor), - '{}_posamb.sorted'.format(transcription_factor) + '{}_posamb'.format(tf_name), + '{}_posamb.sorted'.format(tf_name) ) _metadata_df.to_csv( _unsorted_dir, sep='\t', header=False, index=False @@ -65,9 +59,9 @@ def write_label_bigwigs(): # Write the binned labels to bigwig files - genome-wide labels chromsizes_list = [(k, v) for k, v in chrom_sizes.items()] - for ct in _all_celltypes: + for ct in celltypes: ct_labels_bw_path = _data_dir + "labels/{}/{}_{}.bigwig".format( - transcription_factor, transcription_factor, ct) + tf_name, tf_name, ct) df = mdf_posamb[mdf_posamb['celltype'] == ct] bw = pyBigWig.open(ct_labels_bw_path, "w") bw.addHeader(chromsizes_list) @@ -76,52 +70,65 @@ def write_label_bigwigs(): bw.close() -def write_(): - stride = 6400 +def write_metadata_products(celltypes, stride=6400, posamb_only=False): itime = time.time() + tf_name = 'MAX' + celltype_mdta = [] + celltype_labels = [] mdf_posamb = pd.read_csv( - _sorted_dir, + _data_dir + 'labels/{}/{}_posamb.sorted.bed'.format(tf_name, tf_name), sep='\t', header=None, index_col=None, names=['chr', 'start', 'stop', 'y', 'celltype'] ) - celltype_mdta = [] - celltype_labels = [] - - for ct in _all_celltypes: - ct_labels_bw_path = _data_dir + "labels/MAX/MAX_{}.bigwig".format(ct) - df = mdf_posamb[mdf_posamb['celltype'] == ct] - df['window_start'] = stride*(df['start'] // stride) - uniq_windows = np.unique(["{}:{}".format(x[0], x[1]) for x in zip(df['chr'], df['window_start'])]) + # Retrieve only the windows containing positively/ambiguously labeled bins (if posamb_only==True), or all windows (if posamb_only==False). + for ct in celltypes: + ct_labels_bw_path = _data_dir + "labels/{}/{}_{}.bigwig".format(tf_name, tf_name, ct) df_construction = [] mdta_labels = [] - bw = pyBigWig.open(ct_labels_bw_path) - num_reps = 0 - for u in uniq_windows: - u_chr = u.split(':')[0] - u_start = int(u.split(':')[1]) - u_end = u_start + stride - x = np.nan_to_num(bw.values(u_chr, u_start, u_end, numpy=True)) - df_construction.append((u_chr, u_start, u_end)) - mdta_labels.append(x[np.arange(0, len(x), 50)]) - num_reps = num_reps + 1 + if posamb_only: # Retrieve only the windows containing positively/ambiguously labeled bins + df = mdf_posamb[mdf_posamb['celltype'] == ct] + df['window_start'] = stride*(df['start'] // stride) + uniq_windows = np.unique(["{}:{}".format(x[0], x[1]) for x in zip(df['chr'], df['window_start'])]) + for u in uniq_windows: + u_chr = u.split(':')[0] + u_start = int(u.split(':')[1]) + u_end = u_start + stride + x = np.nan_to_num(bw.values(u_chr, u_start, u_end, numpy=True)) + df_construction.append((u_chr, u_start, u_end)) + mdta_labels.append(x[np.arange(0, len(x), 50)]) + else: # Retrieve all windows genome-wide + for chrID in bw.chroms(): + chromsize = bw.chroms()[chrID] + # Iterate over windows + for startc in np.arange(int(stride/2), chromsize-(2*stride), stride): + u_end = startc + stride + if u_end > chromsize: + break + x = np.nan_to_num(bw.values(chrID, startc, u_end, numpy=True)) + df_construction.append((chrID, startc, u_end)) + mdta_labels.append(x[np.arange(0, len(x), 50)]) + print(ct, chrID, time.time() - itime) celltype_mdta_df = pd.DataFrame(df_construction, columns=['chr', 'start', 'stop']) celltype_mdta_df.insert(len(celltype_mdta_df.columns), 'celltype', ct) celltype_mdta.append(celltype_mdta_df) celltype_labels.append(np.stack(mdta_labels)) print(ct, time.time() - itime) bw.close() - # break print(time.time() - itime) - # _metadata_df - - pd.concat(celltype_mdta).to_csv( - _data_dir + 'labels/MAX/metadata_df.bed', + + all_metadata_df = pd.concat(celltype_mdta) + all_metadata_df.to_csv( + _data_dir + 'labels/{}/metadata_df.bed'.format(tf_name), sep='\t', header=False, index=False ) - np.save(_data_dir + 'labels/MAX/metadata_y.npy', np.vstack(celltype_labels)) - print(time.time() - itime) + np.save(_data_dir + 'labels/{}/metadata_y.npy'.format(tf_name), np.vstack(celltype_labels)) if __name__ == '__main__': - write_label_bigwigs() + _train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] + _val_celltype = ['A549'] + _test_celltype = ['GM12878'] + _all_celltypes = _train_celltypes + _val_celltype + _test_celltype + write_label_bigwigs(_all_celltypes) + write_metadata_products(_all_celltypes) \ No newline at end of file diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 3723086d..6fca0d06 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -28,7 +28,7 @@ class EncodeTFBSDataset(WILDSDataset): _dataset_name = 'encode-tfbs' _versions_dict = { '1.0': { - 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x8b3255e21e164cd98d3aeec09cd0bc26/contents/blob/', + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0xf1fdad4a8af1449eb519bc89d4af8f0a/contents/blob/', 'compressed_size': None}} def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): @@ -185,7 +185,10 @@ def get_input(self, idx, window_size=12800): interval_end = interval_start + window_size seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end] dnase_bw = self._dnase_allcelltypes[this_metadata['celltype']] - dnase_this = dnase_bw.values(chrom, interval_start, interval_end, numpy=True) + try: + dnase_this = dnase_bw.values(chrom, interval_start, interval_end, numpy=True) + except RuntimeError: + print("error", chrom, interval_start, interval_end) assert(np.isnan(seq_this).sum() == 0) assert(np.isnan(dnase_this).sum() == 0) From b9868cf5508b49765a74495e059d79a002798804 Mon Sep 17 00:00:00 2001 From: aikanor Date: Tue, 23 Mar 2021 18:33:30 -0700 Subject: [PATCH 104/244] dataset fix --- wilds/datasets/encodetfbs_dataset.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 6fca0d06..00f856c8 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -127,6 +127,11 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' self._split_array[chrom_mask & celltype_mask] = self._split_dict[split] indices_to_keep = (self._split_array != -1) + # Remove all-zero sequences from training. + train_msk = (self._split_array == full_dataset.split_dict['train']) + allzeroes_msk = (self._y_array.sum(axis=1) == 0).numpy() + indices_to_keep = indices_to_keep & ~(train_msk & allzeroes_msk) + self._metadata_df = self._metadata_df[indices_to_keep] self._split_array = self._split_array[indices_to_keep] self._y_array = self._y_array[indices_to_keep] From 1ba9696cdb132510140d396ead2046a26971472a Mon Sep 17 00:00:00 2001 From: aikanor Date: Tue, 23 Mar 2021 18:35:44 -0700 Subject: [PATCH 105/244] dataset fix --- wilds/datasets/encodetfbs_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 00f856c8..dcaf653f 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -128,7 +128,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' indices_to_keep = (self._split_array != -1) # Remove all-zero sequences from training. - train_msk = (self._split_array == full_dataset.split_dict['train']) + train_msk = (self._split_array == self._split_dict['train']) allzeroes_msk = (self._y_array.sum(axis=1) == 0).numpy() indices_to_keep = indices_to_keep & ~(train_msk & allzeroes_msk) From c53239e7f95c8a570448fd0ba657e402fe5fb606 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Tue, 23 Mar 2021 20:26:32 -0700 Subject: [PATCH 106/244] no-op --- wilds/datasets/encodetfbs_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 6fca0d06..65cd6e4f 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -189,7 +189,7 @@ def get_input(self, idx, window_size=12800): dnase_this = dnase_bw.values(chrom, interval_start, interval_end, numpy=True) except RuntimeError: print("error", chrom, interval_start, interval_end) - + assert(np.isnan(seq_this).sum() == 0) assert(np.isnan(dnase_this).sum() == 0) return torch.tensor(np.column_stack( From a06b5aceb2d0912d0ac9f7d5a96508ed765e42ed Mon Sep 17 00:00:00 2001 From: aikanor Date: Wed, 24 Mar 2021 13:36:56 -0700 Subject: [PATCH 107/244] subsampling chromosomes --- wilds/datasets/encodetfbs_dataset.py | 31 +++++++++++++++++++++------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index dcaf653f..a509df42 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -16,7 +16,7 @@ class EncodeTFBSDataset(WILDSDataset): 12800-base-pair regions of sequence with a quantified chromatin accessibility readout. Label (y): - y is binary. It is 1 if the central 200bp region is bound by the transcription factor MAX, and 0 otherwise. + y is a 128-bit vector, with each element y_i indicating the binding status of a 200bp window. It is 1 if this 200bp region is bound by the transcription factor, and 0 otherwise. If the window x starts at coordinate sc, y_i is the label of the window starting at coordinate (sc+3200)+(50*i). Metadata: Each sequence is annotated with the celltype of origin (a string) and the chromosome of origin (a string). @@ -28,7 +28,7 @@ class EncodeTFBSDataset(WILDSDataset): _dataset_name = 'encode-tfbs' _versions_dict = { '1.0': { - 'download_url': 'https://worksheets.codalab.org/rest/bundles/0xf1fdad4a8af1449eb519bc89d4af8f0a/contents/blob/', + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x7efd626149d648f699d9e686d7aa81a9/contents/blob/', 'compressed_size': None}} def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): @@ -54,9 +54,9 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' self.y_array[self.y_array == 0.5] = float('nan') # Construct splits - train_chroms = ['chr3']#, 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] - val_chroms = ['chr2']#, 'chr9', 'chr11'] - test_chroms = ['chr1']#, 'chr8', 'chr21'] + train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] + val_chroms = ['chr2', 'chr9', 'chr11'] + test_chroms = ['chr1', 'chr8', 'chr21'] train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] val_celltype = ['A549'] test_celltype = ['GM12878'] @@ -128,9 +128,24 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' indices_to_keep = (self._split_array != -1) # Remove all-zero sequences from training. - train_msk = (self._split_array == self._split_dict['train']) - allzeroes_msk = (self._y_array.sum(axis=1) == 0).numpy() - indices_to_keep = indices_to_keep & ~(train_msk & allzeroes_msk) + remove_allnegative = True + if remove_allnegative: + train_msk = (self._split_array == self._split_dict['train']) + allzeroes_msk = (self._y_array.sum(axis=1) == 0).numpy() + indices_to_keep = indices_to_keep & ~(train_msk & allzeroes_msk) + # Subsample the testing and validation indices + val_msk = (self._split_array == self._split_dict['val']) + test_msk = (self._split_array == self._split_dict['test']) + idval_msk = (self._split_array == self._split_dict['id_val']) + subsamp_factor_id = 15 + subsamp_factor_ood = 3 + + keep_mask_ood = np.random.binomial(1, 1.0/subsamp_factor_ood, size=len(indices_to_keep)).astype(bool) + indices_to_keep = indices_to_keep & ~(~keep_mask_ood & val_msk) + indices_to_keep = indices_to_keep & ~(~keep_mask_ood & test_msk) + + keep_mask_id = np.random.binomial(1, 1.0/subsamp_factor_id, size=len(indices_to_keep)).astype(bool) + indices_to_keep = indices_to_keep & ~(~keep_mask_id & idval_msk) self._metadata_df = self._metadata_df[indices_to_keep] self._split_array = self._split_array[indices_to_keep] From f9276904f3c127189b135a2bf0de2d8a551ea3ec Mon Sep 17 00:00:00 2001 From: aikanor Date: Thu, 25 Mar 2021 06:37:53 -0700 Subject: [PATCH 108/244] small edits + addition to metrics --- examples/configs/datasets.py | 6 +- examples/configs/supported.py | 3 +- sandbox_data.ipynb | 620 ------------------ sandbox_model.ipynb | 982 ---------------------------- wilds/common/metrics/all_metrics.py | 60 +- 5 files changed, 39 insertions(+), 1632 deletions(-) delete mode 100644 sandbox_data.ipynb delete mode 100644 sandbox_model.ipynb diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 4618a236..1e3ce0c5 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -118,10 +118,10 @@ 'val_metric_decreasing': False, 'optimizer': 'Adam', 'scheduler': None, - 'batch_size': 64, + 'batch_size': 128, 'lr': 0.0001, - 'weight_decay': 0.01, - 'n_epochs': 5, + 'weight_decay': 1e-5, + 'n_epochs': 1, 'n_groups_per_batch': 2, 'algo_log_metric': 'multitask_binary_accuracy', # 'irm_lambda': 1.0, diff --git a/examples/configs/supported.py b/examples/configs/supported.py index 6c60ea53..a49732a8 100644 --- a/examples/configs/supported.py +++ b/examples/configs/supported.py @@ -4,7 +4,7 @@ # metrics from wilds.common.metrics.loss import ElementwiseLoss, Loss, MultiTaskLoss -from wilds.common.metrics.all_metrics import Accuracy, MultiTaskAccuracy, MSE, multiclass_logits_to_pred, binary_logits_to_pred, MultiTaskAveragePrecision, MTAveragePrecision +from wilds.common.metrics.all_metrics import Accuracy, MultiTaskAccuracy, MSE, multiclass_logits_to_pred, binary_logits_to_pred, MultiTaskAveragePrecision, MultiTaskPREven losses = { 'cross_entropy': ElementwiseLoss(loss_fn=nn.CrossEntropyLoss(reduction='none')), @@ -19,6 +19,7 @@ 'multitask_accuracy': MultiTaskAccuracy(prediction_fn=multiclass_logits_to_pred), 'multitask_binary_accuracy': MultiTaskAccuracy(prediction_fn=binary_logits_to_pred), 'multitask_avgprec': MultiTaskAveragePrecision(prediction_fn=None), + # 'multitask_preven': MultiTaskPREven(prediction_fn=None), None: None, } diff --git a/sandbox_data.ipynb b/sandbox_data.ipynb deleted file mode 100644 index c465e0ab..00000000 --- a/sandbox_data.ipynb +++ /dev/null @@ -1,620 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "- examples\n", - " - run_expt.py\n", - " - configs\n", - " - [x] supported.py\n", - " - [x] model.py\n", - " - [x] datasets.py\n", - " - models\n", - " - [x] CNN_genome.py\n", - " - train.py\n", - " - utils.py\n", - "- wilds\n", - " - [x] datasets/encodetfbs_dataset.py\n", - " - common\n", - " - metrics\n", - " - [x] all_metrics.py\n", - " - data_loaders.py\n", - " - grouper.py\n", - " - [x] utils.py ( threshold_at_recall() )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# TODOs\n", - "\n", - "- change evaluation/validation metric\n", - " - [ ] examples/configs/datasets.py\n", - "- Add `RELEASE_v1.0.txt` to codalab archive\n", - "- Citation/license for wilds/datasets/encodetfbs_dataset.py\n", - "- (optional) change sequence length of model\n", - " - [ ] examples/configs/model.py\n", - " - [ ] examples/models/CNN_genome.py" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Initialize dataset object" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "50.2965240479\n", - "58.1326179504\n" - ] - } - ], - "source": [ - "import numpy as np, pandas as pd, os, time\n", - "import torch, torchvision\n", - "\n", - "data_dir = '/oak/stanford/groups/akundaje/abalsubr/DREAM/wilds/codalab_archive/'\n", - "tf = 'MAX'\n", - "itime = time.time()\n", - "train_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.train.labels.tsv.gz'.format(tf)), sep='\\t')\n", - "print(time.time() - itime)\n", - "val_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.val.labels.tsv.gz'.format(tf)), sep='\\t')\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", - "val_celltype = ['A549']\n", - "test_celltype = ['GM12878']\n", - "all_celltypes = train_celltypes + val_celltype + test_celltype\n", - "\n", - "metadata_map = {}\n", - "metadata_map['chr'] = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']\n", - "metadata_map['celltype'] = all_celltypes\n", - "\n", - "_split_dict = {\n", - " 'train': 0,\n", - " 'val-id': 1,\n", - " 'test': 2,\n", - " 'val-ood': 3\n", - "}\n", - "_split_names = {\n", - " 'train': 'Train',\n", - " 'val-id': 'Validation (ID)',\n", - " 'test': 'Test',\n", - " 'val-ood': 'Validation (OOD)',\n", - "}\n", - "_split_scheme = 'standard'" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1.40137600899\n", - "('chr1', 4.365410089492798)\n", - "('chr2', 8.54686713218689)\n", - "('chr3', 11.915641069412231)\n", - "('chr4', 15.147382020950317)\n", - "('chr5', 18.221237182617188)\n", - "('chr6', 21.16081714630127)\n", - "('chr7', 23.87936806678772)\n", - "('chr8', 26.382845163345337)\n", - "('chr9', 28.802964210510254)\n", - "('chr10', 31.10539698600769)\n", - "('chr11', 33.392733097076416)\n", - "('chr12', 35.6597261428833)\n", - "('chr13', 37.56297421455383)\n", - "('chr14', 39.363978147506714)\n", - "('chr15', 41.089357137680054)\n", - "('chr16', 42.6117000579834)\n", - "('chr17', 43.9806342124939)\n", - "('chr18', 45.29493808746338)\n", - "('chr19', 46.26894497871399)\n", - "('chr20', 47.31300115585327)\n", - "('chr21', 48.139018058776855)\n", - "('chr22', 48.97876214981079)\n", - "('chrX', 51.61549210548401)\n", - "('H1-hESC', 24.14024806022644)\n", - "('HCT116', 47.97159004211426)\n", - "('HeLa-S3', 72.82926392555237)\n", - "('HepG2', 97.18733406066895)\n", - "('K562', 121.94148206710815)\n", - "('A549', 147.29550194740295)\n", - "('GM12878', 171.71312499046326)\n" - ] - } - ], - "source": [ - "itime = time.time()\n", - "sequence_filename = os.path.join(data_dir, 'sequence.npz')\n", - "seq_arr = np.load(sequence_filename)\n", - "print(time.time() - itime)\n", - "\n", - "itime = time.time()\n", - "_seq_bp = {}\n", - "for chrom in seq_arr:\n", - " _seq_bp[chrom] = seq_arr[chrom]\n", - " print(chrom, time.time() - itime)\n", - "print(\"Sequence read. Time: {}\".format(time.time() - itime))\n", - "\n", - "itime = time.time()\n", - "_dnase_allcelltypes = {}\n", - "for ct in all_celltypes:\n", - " dnase_filename = os.path.join(data_dir, '{}_dnase.npz'.format(ct))\n", - " dnase_npz_file = np.load(dnase_filename)\n", - " _dnase_allcelltypes[ct] = {}\n", - " for chrom in _seq_bp:\n", - " _dnase_allcelltypes[ct][chrom] = dnase_npz_file[chrom]\n", - " print(ct, time.time() - itime)\n", - "print(\"DNase read for all celltypes. Time: {}\".format(time.time() - itime))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'all_df' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# len(_dnase_allcelltypes)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mall_df\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mNameError\u001b[0m: name 'all_df' is not defined" - ] - } - ], - "source": [ - "# len(_dnase_allcelltypes)\n", - "all_df" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "ename": "AttributeError", - "evalue": "'module' object has no attribute 'isin'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mtr_chrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr2'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr9'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr11'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mte_chrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr1'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr8'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr21'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mtraining_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtr_chrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mval_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mval_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mte_chrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mall_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtraining_df\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_df\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mAttributeError\u001b[0m: 'module' object has no attribute 'isin'" - ] - } - ], - "source": [ - "tr_chrs = ['chr2', 'chr9', 'chr11']\n", - "te_chrs = ['chr1', 'chr8', 'chr21']\n", - "training_df = train_chr[np.isin(train_chr['chr'], tr_chrs)]\n", - "val_df = val_chr[np.isin(val_chr['chr'], te_chrs)]\n", - "all_df = pd.concat([training_df, val_df])\n", - "\n", - "#filter_msk = all_df['start'] >= 0\n", - "filter_msk = all_df['start']%1000 == 0\n", - "all_df = all_df[filter_msk]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "itime = time.time()\n", - "pd_list = []\n", - "for ct in all_celltypes:\n", - " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", - " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", - " tc_chr['celltype'] = ct\n", - " pd_list.append(tc_chr)\n", - "metadata_df = pd.concat(pd_list)\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "itime = time.time()\n", - "y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", - "non_ambig_mask = (y_array != -1)\n", - "metadata_df['y'] = y_array\n", - "_metadata_df = metadata_df[non_ambig_mask]\n", - "_y_array = torch.LongTensor(y_array[non_ambig_mask])\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "itime = time.time()\n", - "chr_ints = _metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['chr'])] )).values\n", - "celltype_ints = _metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['celltype'])] )).values\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "train_chr_mask = np.isin(_metadata_df['chr'], tr_chrs)\n", - "val_chr_mask = np.isin(_metadata_df['chr'], te_chrs)\n", - "train_celltype_mask = np.isin(_metadata_df['celltype'], train_celltypes)\n", - "val_celltype_mask = np.isin(_metadata_df['celltype'], val_celltype)\n", - "test_celltype_mask = np.isin(_metadata_df['celltype'], test_celltype)\n", - "\n", - "split_array = -1*np.ones(_metadata_df.shape[0]).astype(int)\n", - "split_array[np.logical_and(train_chr_mask, train_celltype_mask)] = _split_dict['train']\n", - "split_array[np.logical_and(val_chr_mask, test_celltype_mask)] = _split_dict['test']\n", - "split_array[np.logical_and(val_chr_mask, val_celltype_mask)] = _split_dict['val-ood']\n", - "split_array[np.logical_and(val_chr_mask, train_celltype_mask)] = _split_dict['val-id']\n", - "_metadata_df['split'] = split_array\n", - "_split_array = split_array" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "ename": "ImportError", - "evalue": "No module named data", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdataset_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mImportError\u001b[0m: No module named data" - ] - } - ], - "source": [ - "from torch.utils.data import DataLoader\n", - "from data import dataset_attributes" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "from PIL import Image\n", - "import argparse\n", - "class ParseKwargs(argparse.Action):\n", - " def __call__(self, parser, namespace, values, option_string=None):\n", - " setattr(namespace, self.dest, dict())\n", - " for value in values:\n", - " key, value_str = value.split('=')\n", - " if value_str.replace('-','').isnumeric():\n", - " processed_val = int(value_str)\n", - " elif value_str.replace('-','').replace('.','').isnumeric():\n", - " processed_val = float(value_str)\n", - " elif value_str in ['True', 'true']:\n", - " processed_val = True\n", - " elif value_str in ['False', 'false']:\n", - " processed_val = False\n", - " else:\n", - " processed_val = value_str\n", - " getattr(namespace, self.dest)[key] = processed_val" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'algorithm_constructors' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 31\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0;31m# Algorithm and objective\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 33\u001b[0;31m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'--algorithm'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrequired\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mchoices\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0malgorithm_constructors\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 34\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'--algorithm_kwargs'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'*'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maction\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mParseKwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 35\u001b[0m \u001b[0mparser\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd_argument\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'--groupby_fields'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'+'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdefault\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mNameError\u001b[0m: name 'algorithm_constructors' is not defined" - ] - } - ], - "source": [ - "ROOTDIR = '/oak/stanford/groups/akundaje/abalsubr/wilds_other'\n", - "args_kw = \"-d camelyon17 --algorithm ERM --model densenet121 --split_scheme standard --groupby_fields hospital --loss_function cross_entropy --optimizer SGD --lr 0.0001 --batch_size 32 --weight_decay 0 --n_epochs 10 --scheduler ReduceLROnPlateau --scheduler_metric_split val --scheduler_metric_name acc_avg --log_dir log --log_every 50 --save_step 1000 --save_best --save_last --seed 0 --evaluate_all_splits --root_dir {}\".format(\n", - " ROOTDIR).split()\n", - "\n", - "parser = argparse.ArgumentParser()\n", - "\n", - "# Dataset\n", - "parser.add_argument('-d', '--dataset', choices=['encodeTFBS', 'amazon', 'camelyon17', 'celebA', 'civilcomments', 'iwildcam', 'waterbirds', 'yelp', 'poverty', 'fmow', 'ogbg-molpcba'], required=True)\n", - "parser.add_argument('--split_scheme', default='standard',\n", - " help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')\n", - "parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--root_dir', default=None, required=True,\n", - " help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')\n", - "parser.add_argument('--download', default=False, action='store_true',\n", - " help='If true, tries to downloads the dataset if it does not exist in root_dir.')\n", - "parser.add_argument('--frac', type=float, default=1.0,\n", - " help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.')\n", - "\n", - "# Loaders\n", - "parser.add_argument('--train_loader', choices=['standard', 'group'], default='standard')\n", - "parser.add_argument('--train_loader_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--eval_loader', choices=['standard', 'group'], default='standard')\n", - "parser.add_argument('--batch_size', type=int, default=32)\n", - "parser.add_argument('--no_pin_memory', action='store_true') # TODO: put as loader_kwargs\n", - "parser.add_argument('--num_workers', type=int, default=4) # TODO: put as loader kwargs\n", - "\n", - "# Model\n", - "parser.add_argument(\n", - " '--model',\n", - " choices=['bert-base-uncased', 'inception_v3', 'densenet121', 'wideresnet50', 'resnet50', 'gin-virtual', 'resnet18_ms'],\n", - " default='resnet50')\n", - "parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},\n", - " help='keyword arguments for model initialization passed as key1=value1 key2=value2')\n", - "parser.add_argument('--train_from_scratch', action='store_true', default=False)\n", - "\n", - "# Algorithm and objective\n", - "parser.add_argument('--algorithm', required=True, choices=algorithm_constructors.keys())\n", - "parser.add_argument('--algorithm_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--groupby_fields', nargs='+', default=None)\n", - "parser.add_argument('--loss_function', required=True, choices = losses.keys()) #TODO: make default\n", - "parser.add_argument('--val_metric', default=None)\n", - "\n", - "# Optimization\n", - "parser.add_argument('--n_epochs', type=int, default=4)\n", - "parser.add_argument('--optimizer', default=None, choices=optimizer_attributes.keys())\n", - "parser.add_argument('--lr', type=float, required=True)\n", - "parser.add_argument('--weight_decay', type=float, required=True)\n", - "parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--scheduler', default=None, choices=scheduler_attributes.keys())\n", - "parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')\n", - "parser.add_argument('--scheduler_metric_name')\n", - "\n", - "# Evaluation\n", - "parser.add_argument('--evaluate_all_splits', action='store_true', default=False)\n", - "parser.add_argument('--additional_eval_splits', nargs='+', default=[])\n", - "\n", - "# Misc\n", - "parser.add_argument('--device', type=int, default=0)\n", - "parser.add_argument('--seed', type=int, default=0)\n", - "parser.add_argument('--log_dir', default='./logs')\n", - "parser.add_argument('--log_every', default=50, type=int)\n", - "parser.add_argument('--save_step', type=int, default=None)\n", - "parser.add_argument('--save_best', action='store_true', default=False)\n", - "parser.add_argument('--save_last', action='store_true', default=False)\n", - "parser.add_argument('--save_outputs', action='store_true', default=False)\n", - "parser.add_argument('--no_group_logging', action='store_true', default=False)\n", - "parser.add_argument('--val_metric_decreasing', action='store_true', default=False)\n", - "parser.add_argument('--use_wandb', action='store_true', default=False)\n", - "parser.add_argument('--progress_bar', action='store_true', default=False)\n", - "parser.add_argument('--resume', default=False, action='store_true')\n", - "parser.add_argument('--eval_only', default=False, action='store_true')\n", - "\n", - "args = parser.parse_args(args_kw)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# get_input (idx)" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name '_metadata_df' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0midx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mthis_metadata\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_metadata_df\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miloc\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0midx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mitime\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mflank_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m400\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mNameError\u001b[0m: name '_metadata_df' is not defined" - ] - } - ], - "source": [ - "idx = 3\n", - "this_metadata = _metadata_df.iloc[idx, :]\n", - "\n", - "itime = time.time()\n", - "flank_size = 400\n", - "interval_start = this_metadata['start'] - flank_size\n", - "interval_end = this_metadata['stop'] + flank_size\n", - "dnase_this = _dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end]\n", - "seq_this = _seq_bp[this_metadata['chr']][interval_start:interval_end]\n", - "data = np.column_stack([seq_this, dnase_this])\n", - "# print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.028102874755859375\n" - ] - } - ], - "source": [ - "itime = time.time()\n", - "metadata_array = torch.stack(\n", - " (torch.LongTensor(chr_ints), \n", - " torch.LongTensor(celltype_ints), \n", - " _y_array),\n", - " dim=1)\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "metadata": {}, - "outputs": [ - { - "ename": "ModuleNotFoundError", - "evalue": "No module named 'torch_scatter'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m#data.shape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata_loaders\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_train_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mget_eval_loader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m~/dr_benchmark/wilds/common/data_loaders.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mDataLoader\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msampler\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mWeightedRandomSampler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mSubsetRandomSampler\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mwilds\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcommon\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mget_counts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msplit_into_groups\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/dr_benchmark/wilds/common/utils.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch_scatter\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mSubset\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mpandas\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtypes\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mCategoricalDtype\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torch_scatter'" - ] - } - ], - "source": [ - "#data.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 157, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "4600" - ] - }, - "execution_count": 157, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "data.shape\n", - "interval_end\n", - "# itime = time.time()\n", - "# np.save(os.path.join(data_dir, 'stmp.npy'), sa)\n", - "# print(time.time() - itime)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Run training experiment" - ] - }, - { - "cell_type": "code", - "execution_count": 167, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'python3 examples/run_expt.py -d encodeTFBS --algorithm ERM --model densenet121 --split_scheme standard --groupby_fields hospital --loss_function cross_entropy --optimizer SGD --lr 0.0001 --batch_size 32 --weight_decay 0 --n_epochs 10 --scheduler ReduceLROnPlateau --scheduler_metric_split val --scheduler_metric_name acc_avg --log_dir log --log_every 50 --save_step 1000 --save_best --save_last --seed 0 --evaluate_all_splits --root_dir ROOTDIR'" - ] - }, - "execution_count": 167, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "cmdstr = \"python3 examples/run_expt.py -d encodeTFBS --algorithm ERM --model densenet121 --split_scheme standard --groupby_fields hospital --loss_function cross_entropy\"\n", - "cmdstr += \" \"\n", - "cmdstr += \"--optimizer SGD --lr 0.0001 --batch_size 32 --weight_decay 0 --n_epochs 10 --scheduler ReduceLROnPlateau --scheduler_metric_split val --scheduler_metric_name acc_avg\"\n", - "cmdstr += \" \"\n", - "cmdstr += \"--log_dir log --log_every 50 --save_step 1000 --save_best --save_last --seed 0 --evaluate_all_splits --root_dir ROOTDIR\"\n", - "cmdstr" - ] - }, - { - "cell_type": "code", - "execution_count": 164, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name '_metadata_array' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0m_metadata_array\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mNameError\u001b[0m: name '_metadata_array' is not defined" - ] - } - ], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.5" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/sandbox_model.ipynb b/sandbox_model.ipynb deleted file mode 100644 index 2d62b55e..00000000 --- a/sandbox_model.ipynb +++ /dev/null @@ -1,982 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Initialize dataset object" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "57.8772239685\n", - "66.8270189762\n" - ] - } - ], - "source": [ - "import numpy as np, pandas as pd, os, time, torch, torchvision\n", - "data_dir = '/oak/stanford/groups/akundaje/abalsubr/DREAM/wilds/codalab_archive/'\n", - "tf = 'MAX'\n", - "itime = time.time()\n", - "train_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.train.labels.tsv.gz'.format(tf)), sep='\\t')\n", - "print(time.time() - itime)\n", - "val_chr = pd.read_csv(os.path.join(data_dir, 'labels/{}.val.labels.tsv.gz'.format(tf)), sep='\\t')\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", - "val_celltype = ['A549']\n", - "test_celltype = ['GM12878']\n", - "all_celltypes = train_celltypes + val_celltype + test_celltype\n", - "\n", - "metadata_map = {}\n", - "metadata_map['chr'] = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']\n", - "metadata_map['celltype'] = all_celltypes\n", - "\n", - "_split_dict = {\n", - " 'train': 0,\n", - " 'val-id': 1,\n", - " 'test': 2,\n", - " 'val-ood': 3\n", - "}\n", - "_split_names = {\n", - " 'train': 'Train',\n", - " 'val-id': 'Validation (ID)',\n", - " 'test': 'Test',\n", - " 'val-ood': 'Validation (OOD)'\n", - "}\n", - "_split_scheme = 'standard'" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "('H1-hESC', 25.299736976623535)\n", - "('HCT116', 49.68733310699463)\n", - "('HeLa-S3', 74.65905213356018)\n", - "('HepG2', 99.33112812042236)\n", - "('K562', 124.1327919960022)\n", - "('A549', 149.19999814033508)\n", - "('GM12878', 174.0277030467987)\n" - ] - } - ], - "source": [ - "itime = time.time()\n", - "sequence_filename = os.path.join(data_dir, 'sequence.npz')\n", - "seq_arr = np.load(sequence_filename)\n", - "print(time.time() - itime)\n", - "\n", - "itime = time.time()\n", - "_seq_bp = {}\n", - "for chrom in seq_arr:\n", - " _seq_bp[chrom] = seq_arr[chrom]\n", - " print(chrom, time.time() - itime)\n", - "itime = time.time()\n", - "_dnase_allcelltypes = {}\n", - "for ct in all_celltypes:\n", - " dnase_filename = os.path.join(data_dir, '{}_dnase.npz'.format(ct))\n", - " dnase_npz_file = np.load(dnase_filename)\n", - " _dnase_allcelltypes[ct] = {}\n", - " for chrom in _seq_bp:\n", - " _dnase_allcelltypes[ct][chrom] = dnase_npz_file[chrom]\n", - " print(ct, time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class Beagle2(nn.Module):\n", - " \"\"\"\n", - " Neural net models over genomic sequence.\n", - " Input:\n", - " - sequence_length: int (default 1000) \n", - " - Shape: (N, 5, sequence_length, 1) with batch size N.\n", - " \n", - " Output:\n", - " - prediction (Tensor): float torch tensor of shape (N, )\n", - " \n", - " TODO: Finish docstring.\n", - " \"\"\"\n", - " def __init__(self):\n", - " \"\"\"\n", - " Parameters\n", - " ----------\n", - " sequence_length : int\n", - " n_genomic_features : int\n", - " \"\"\"\n", - " super(Beagle2, self).__init__()\n", - "\n", - " self.dropout = 0.3\n", - " self.num_cell_types = 1\n", - " self.conv1 = nn.Conv2d(5, 300, (19, 1), stride = (1, 1), padding=(9,0))\n", - " self.conv2 = nn.Conv2d(300, 200, (11, 1), stride = (1, 1), padding = (5,0))\n", - " self.conv3 = nn.Conv2d(200, 200, (7, 1), stride = (1, 1), padding = (4,0))\n", - " self.bn1 = nn.BatchNorm2d(300)\n", - " self.bn2 = nn.BatchNorm2d(200)\n", - " self.bn3 = nn.BatchNorm2d(200)\n", - " self.maxpool1 = nn.MaxPool2d((3, 1))\n", - " self.maxpool2 = nn.MaxPool2d((4, 1))\n", - " self.maxpool3 = nn.MaxPool2d((4, 1))\n", - "\n", - " self.fc1 = nn.Linear(4200, 1000)\n", - " self.bn4 = nn.BatchNorm1d(1000)\n", - "\n", - " self.fc2 = nn.Linear(1000, 1000)\n", - " self.bn5 = nn.BatchNorm1d(1000)\n", - "\n", - " self.fc3 = nn.Linear(1000, self.num_cell_types)\n", - "\n", - " def forward(self, s):\n", - " s = s.permute(0, 2, 1).contiguous() # batch_size x 4 x 1000\n", - " s = s.view(-1, 5, 1000, 1) # batch_size x 4 x 1000 x 1 [4 channels]\n", - " s = self.maxpool1(F.relu(self.bn1(self.conv1(s)))) # batch_size x 300 x 333 x 1\n", - " s = self.maxpool2(F.relu(self.bn2(self.conv2(s)))) # batch_size x 200 x 83 x 1\n", - " s = self.maxpool3(F.relu(self.bn3(self.conv3(s)))) # batch_size x 200 x 21 x 1\n", - " s = s.view(-1, 4200)\n", - " conv_out = s\n", - "\n", - " s = F.dropout(F.relu(self.bn4(self.fc1(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", - " #s = F.dropout(F.relu(self.bn5(self.fc2(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", - " \n", - " \n", - " s = self.fc3(s)\n", - "\n", - " return s, conv_out\n", - "\n", - "\n", - "class DanQ(nn.Module):\n", - " def __init__(self, sequence_length, n_genomic_features):\n", - " \"\"\"\n", - " Parameters\n", - " ----------\n", - " sequence_length : int\n", - " Input sequence length\n", - " n_genomic_features : int\n", - " Total number of features to predict\n", - " \"\"\"\n", - " super(DanQ, self).__init__()\n", - " self.nnet = nn.Sequential(\n", - " nn.Conv1d(4, 320, kernel_size=26),\n", - " nn.ReLU(inplace=True),\n", - " nn.MaxPool1d(\n", - " kernel_size=13, stride=13),\n", - " nn.Dropout(0.2))\n", - "\n", - " self.bdlstm = nn.Sequential(\n", - " nn.LSTM(\n", - " 320, 320, num_layers=1, batch_first=True, bidirectional=True))\n", - "\n", - " self._n_channels = math.floor(\n", - " (sequence_length - 25) / 13)\n", - " self.classifier = nn.Sequential(\n", - " nn.Dropout(0.5),\n", - " nn.Linear(self._n_channels * 640, 925),\n", - " nn.ReLU(inplace=True),\n", - " nn.Linear(925, n_genomic_features),\n", - " nn.Sigmoid())\n", - "\n", - " def forward(self, x):\n", - " \"\"\"Forward propagation of a batch.\n", - " \"\"\"\n", - " out = self.nnet(x)\n", - " reshape_out = out.transpose(0, 1).transpose(0, 2)\n", - " out, _ = self.bdlstm(reshape_out)\n", - " out = out.transpose(0, 1)\n", - " reshape_out = out.contiguous().view(\n", - " out.size(0), 640 * self._n_channels)\n", - " predict = self.classifier(reshape_out)\n", - " return predict\n", - "\n", - "\n", - "class DeepSEA(nn.Module):\n", - " def __init__(self, sequence_length, n_genomic_features):\n", - " \"\"\"\n", - " Parameters\n", - " ----------\n", - " sequence_length : int\n", - " n_genomic_features : int\n", - " \"\"\"\n", - " super(DeepSEA, self).__init__()\n", - " conv_kernel_size = 8\n", - " pool_kernel_size = 4\n", - "\n", - " self.conv_net = nn.Sequential(\n", - " nn.Conv1d(4, 320, kernel_size=conv_kernel_size),\n", - " nn.ReLU(inplace=True),\n", - " nn.MaxPool1d(\n", - " kernel_size=pool_kernel_size, stride=pool_kernel_size),\n", - " nn.Dropout(p=0.2),\n", - "\n", - " nn.Conv1d(320, 480, kernel_size=conv_kernel_size),\n", - " nn.ReLU(inplace=True),\n", - " nn.MaxPool1d(\n", - " kernel_size=pool_kernel_size, stride=pool_kernel_size),\n", - " nn.Dropout(p=0.2),\n", - "\n", - " nn.Conv1d(480, 960, kernel_size=conv_kernel_size),\n", - " nn.ReLU(inplace=True),\n", - " nn.Dropout(p=0.5))\n", - "\n", - " reduce_by = conv_kernel_size - 1\n", - " pool_kernel_size = float(pool_kernel_size)\n", - " self.n_channels = int(\n", - " np.floor(\n", - " (np.floor(\n", - " (sequence_length - reduce_by) / pool_kernel_size)\n", - " - reduce_by) / pool_kernel_size)\n", - " - reduce_by)\n", - " self.classifier = nn.Sequential(\n", - " nn.Linear(960 * self.n_channels, n_genomic_features),\n", - " nn.ReLU(inplace=True),\n", - " nn.Linear(n_genomic_features, n_genomic_features),\n", - " nn.Sigmoid())\n", - "\n", - " def forward(self, x):\n", - " \"\"\"Forward propagation of a batch.\n", - " \"\"\"\n", - " out = self.conv_net(x)\n", - " reshape_out = out.view(out.size(0), 960 * self.n_channels)\n", - " predict = self.classifier(reshape_out)\n", - " return predict" - ] - }, - { - "cell_type": "code", - "execution_count": 78, - "metadata": {}, - "outputs": [], - "source": [ - "import math\n", - "import numpy as np\n", - "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "\n", - "class Beagle(nn.Module):\n", - " \"\"\"\n", - " Neural net models over genomic sequence.\n", - " Input:\n", - " - sequence_length: int (default 1000) \n", - " - Shape: (N, 5, sequence_length, 1) with batch size N.\n", - " \n", - " Output:\n", - " - prediction (Tensor): float torch tensor of shape (N, )\n", - " \n", - " TODO: Finish docstring.\n", - " \"\"\"\n", - " def __init__(self):\n", - " \"\"\"\n", - " Parameters\n", - " ----------\n", - " sequence_length : int\n", - " n_genomic_features : int\n", - " \"\"\"\n", - " super(Beagle, self).__init__()\n", - "\n", - " self.dropout = 0.3\n", - " self.num_cell_types = 1\n", - " self.conv1 = nn.Conv2d(5, 300, (19, 1), stride = (1, 1), padding=(9,0))\n", - " self.conv2 = nn.Conv2d(300, 200, (11, 1), stride = (1, 1), padding = (5,0))\n", - " self.conv3 = nn.Conv2d(200, 200, (7, 1), stride = (1, 1), padding = (4,0))\n", - " self.bn1 = nn.BatchNorm2d(300)\n", - " self.bn2 = nn.BatchNorm2d(200)\n", - " self.bn3 = nn.BatchNorm2d(200)\n", - " self.maxpool1 = nn.MaxPool2d((3, 1))\n", - " self.maxpool2 = nn.MaxPool2d((4, 1))\n", - " self.maxpool3 = nn.MaxPool2d((4, 1))\n", - "\n", - " self.fc1 = nn.Linear(4200, 1000)\n", - " self.bn4 = nn.BatchNorm1d(1000)\n", - "\n", - " self.fc2 = nn.Linear(1000, 1000)\n", - " self.bn5 = nn.BatchNorm1d(1000)\n", - "\n", - " self.fc3 = nn.Linear(1000, self.num_cell_types)\n", - "\n", - " def forward(self, s):\n", - " s = s.permute(0, 2, 1).contiguous() # batch_size x 5 x 1000\n", - " s = s.view(-1, 5, 1000, 1) # batch_size x 5 x 1000 x 1 [5 channels]\n", - " s = self.maxpool1(F.relu(self.bn1(self.conv1(s)))) # batch_size x 300 x 333 x 1\n", - " s = self.maxpool2(F.relu(self.bn2(self.conv2(s)))) # batch_size x 200 x 83 x 1\n", - " s = self.maxpool3(F.relu(self.bn3(self.conv3(s)))) # batch_size x 200 x 21 x 1\n", - " s = s.view(-1, 4200)\n", - " conv_out = s\n", - "\n", - " s = F.dropout(F.relu(self.bn4(self.fc1(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", - " s = F.dropout(F.relu(self.bn5(self.fc2(s))), p=self.dropout, training=self.training) # batch_size x 1000\n", - " \n", - " s = self.fc3(s)\n", - "\n", - " return s, conv_out" - ] - }, - { - "cell_type": "code", - "execution_count": 86, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[('nnet.0.weight', 33280),\n", - " ('nnet.0.bias', 320),\n", - " ('bdlstm.0.weight_ih_l0', 409600),\n", - " ('bdlstm.0.weight_hh_l0', 409600),\n", - " ('bdlstm.0.bias_ih_l0', 1280),\n", - " ('bdlstm.0.bias_hh_l0', 1280),\n", - " ('bdlstm.0.weight_ih_l0_reverse', 409600),\n", - " ('bdlstm.0.weight_hh_l0_reverse', 409600),\n", - " ('bdlstm.0.bias_ih_l0_reverse', 1280),\n", - " ('bdlstm.0.bias_hh_l0_reverse', 1280),\n", - " ('classifier.1.weight', 592000),\n", - " ('classifier.1.bias', 925),\n", - " ('classifier.3.weight', 4625),\n", - " ('classifier.3.bias', 5)]" - ] - }, - "execution_count": 86, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def count_parameters(model):\n", - " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", - "\n", - "model = Beagle2()\n", - "model = DanQ(50, 5)\n", - "\n", - "lst = [(x[0], x[1].numel()) for x in model.named_parameters()]\n", - "#np.sum([x[1] for x in lst])\n", - "count_parameters(model)\n", - "lst" - ] - }, - { - "cell_type": "code", - "execution_count": 48, - "metadata": {}, - "outputs": [ - { - "ename": "AttributeError", - "evalue": "'module' object has no attribute 'isin'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mtr_chrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr2'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr9'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr11'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mte_chrs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'chr1'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr8'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'chr21'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mtraining_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtr_chrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mval_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mval_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval_chr\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'chr'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mte_chrs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mall_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtraining_df\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_df\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mAttributeError\u001b[0m: 'module' object has no attribute 'isin'" - ] - } - ], - "source": [ - "tr_chrs = ['chr2', 'chr9', 'chr11']\n", - "te_chrs = ['chr1', 'chr8', 'chr21']\n", - "training_df = train_chr[np.isin(train_chr['chr'], tr_chrs)]\n", - "val_df = val_chr[np.isin(val_chr['chr'], te_chrs)]\n", - "all_df = pd.concat([training_df, val_df])\n", - "\n", - "#filter_msk = all_df['start'] >= 0\n", - "filter_msk = all_df['start']%1000 == 0\n", - "all_df = all_df[filter_msk]" - ] - }, - { - "cell_type": "code", - "execution_count": 49, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1.12.1\n" - ] - } - ], - "source": [ - "print(np.__version__)" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/users/abalsubr/anaconda2/envs/scs3/lib/python3.6/site-packages/ipykernel_launcher.py:6: SettingWithCopyWarning: \n", - "A value is trying to be set on a copy of a slice from a DataFrame.\n", - "Try using .loc[row_indexer,col_indexer] = value instead\n", - "\n", - "See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy\n", - " \n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1.659163236618042\n" - ] - } - ], - "source": [ - "itime = time.time()\n", - "pd_list = []\n", - "for ct in all_celltypes:\n", - " tc_chr = all_df[['chr', 'start', 'stop', ct]]\n", - " tc_chr.columns = ['chr', 'start', 'stop', 'y']\n", - " tc_chr['celltype'] = ct\n", - " pd_list.append(tc_chr)\n", - "metadata_df = pd.concat(pd_list)\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "3.0391879081726074\n" - ] - } - ], - "source": [ - "itime = time.time()\n", - "y_array = metadata_df['y'].replace({'U': 0, 'B': 1, 'A': -1}).values\n", - "non_ambig_mask = (y_array != -1)\n", - "metadata_df['y'] = y_array\n", - "_metadata_df = metadata_df[non_ambig_mask]\n", - "_y_array = torch.LongTensor(y_array[non_ambig_mask])\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "12.390011310577393\n" - ] - } - ], - "source": [ - "itime = time.time()\n", - "chr_ints = _metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['chr'])] )).values\n", - "celltype_ints = _metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(metadata_map['celltype'])] )).values\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 53, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/users/abalsubr/anaconda2/envs/scs3/lib/python3.6/site-packages/ipykernel_launcher.py:12: SettingWithCopyWarning: \n", - "A value is trying to be set on a copy of a slice from a DataFrame.\n", - "Try using .loc[row_indexer,col_indexer] = value instead\n", - "\n", - "See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy\n", - " if sys.path[0] == '':\n" - ] - } - ], - "source": [ - "train_chr_mask = np.isin(_metadata_df['chr'], tr_chrs)\n", - "val_chr_mask = np.isin(_metadata_df['chr'], te_chrs)\n", - "train_celltype_mask = np.isin(_metadata_df['celltype'], train_celltypes)\n", - "val_celltype_mask = np.isin(_metadata_df['celltype'], val_celltype)\n", - "test_celltype_mask = np.isin(_metadata_df['celltype'], test_celltype)\n", - "\n", - "split_array = -1*np.ones(_metadata_df.shape[0]).astype(int)\n", - "split_array[np.logical_and(train_chr_mask, train_celltype_mask)] = _split_dict['train']\n", - "split_array[np.logical_and(val_chr_mask, test_celltype_mask)] = _split_dict['test']\n", - "split_array[np.logical_and(val_chr_mask, val_celltype_mask)] = _split_dict['val-ood']\n", - "split_array[np.logical_and(val_chr_mask, train_celltype_mask)] = _split_dict['val-id']\n", - "_metadata_df['split'] = split_array\n", - "_split_array = split_array" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# get_input (idx)" - ] - }, - { - "cell_type": "code", - "execution_count": 153, - "metadata": {}, - "outputs": [], - "source": [ - "idx = 3\n", - "this_metadata = _metadata_df.iloc[idx, :]\n", - "\n", - "itime = time.time()\n", - "flank_size = 400\n", - "interval_start = this_metadata['start'] - flank_size\n", - "interval_end = this_metadata['stop'] + flank_size\n", - "dnase_this = _dnase_allcelltypes[this_metadata['celltype']][this_metadata['chr']][interval_start:interval_end]\n", - "seq_this = _seq_bp[this_metadata['chr']][interval_start:interval_end]\n", - "data = np.column_stack([seq_this, dnase_this])\n", - "# print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 154, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "4600" - ] - }, - "execution_count": 154, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "data.shape\n", - "interval_end\n", - "# itime = time.time()\n", - "# np.save(os.path.join(data_dir, 'stmp.npy'), sa)\n", - "# print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 78, - "metadata": {}, - "outputs": [ - { - "ename": "TypeError", - "evalue": "can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mitime\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m metadata_array = torch.stack(\n\u001b[0;32m----> 3\u001b[0;31m (torch.LongTensor(metadata_df['chr'].values), \n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLongTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetadata_df\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'celltype'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m self._y_array),\n", - "\u001b[0;31mTypeError\u001b[0m: can't convert np.ndarray of type numpy.object_. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool." - ] - } - ], - "source": [ - "itime = time.time()\n", - "metadata_array = torch.stack(\n", - " (torch.LongTensor(chr_ints), \n", - " torch.LongTensor(celltype_ints), \n", - " _y_array),\n", - " dim=1)\n", - "print(time.time() - itime)" - ] - }, - { - "cell_type": "code", - "execution_count": 156, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name '_metadata_array' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0m_metadata_array\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mNameError\u001b[0m: name '_metadata_array' is not defined" - ] - } - ], - "source": [ - "_metadata_array" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "from examples.models.model_attributes import model_attributes" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "ename": "ModuleNotFoundError", - "evalue": "No module named 'utils'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_attributes\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmodel_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mset_seed\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mLogger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mCSVBatchLogger\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlog_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mParseKwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mload\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 22\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mdataset_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mexamples\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0moptimizer_attributes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/dr_benchmark/examples/train.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mutils\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0msave\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIPython\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'utils'" - ] - } - ], - "source": [ - "import os, csv\n", - "import time\n", - "import argparse\n", - "import IPython\n", - "import pandas as pd\n", - "import torch\n", - "import torch.nn as nn\n", - "import torchvision\n", - "import sys\n", - "from collections import defaultdict\n", - "\n", - "# TODO: Replace this once we make wilds into an installed package\n", - "sys.path.insert(1, os.path.join(sys.path[0], '..'))\n", - "\n", - "from wilds.common.data_loaders import get_train_loader, get_eval_loader\n", - "from wilds.common.grouper import CombinatorialGrouper\n", - "from wilds.common.utils import get_counts\n", - "\n", - "from examples.models.model_attributes import model_attributes\n", - "from examples.utils import set_seed, Logger, CSVBatchLogger, log_args, ParseKwargs, load\n", - "from examples.train import train\n", - "from examples.data import dataset_attributes\n", - "from examples.optimizer import optimizer_attributes\n", - "from examples.scheduler import scheduler_attributes\n", - "from examples.loss import losses\n", - "from examples.utils import log_group_data\n", - "from examples.algorithms.constructors import algorithm_constructors\n", - "\n", - "\n", - "def initialize_algorithm(args, datasets, train_grouper):\n", - " train_dataset = datasets['train']['dataset']\n", - " train_loader = datasets['train']['loader']\n", - "\n", - " # Configure the final layer of the networks used\n", - " # The code below are defaults. Edit this if you need special config for your model.\n", - " if (train_dataset.is_classification) and (train_dataset.y_size == 1):\n", - " # For single-task classification, we have one output per class\n", - " d_out = train_dataset.n_classes\n", - " elif (not train_dataset.is_classification):\n", - " # For regression, we have one output per target dimension\n", - " d_out = train_dataset.y_size\n", - " else:\n", - " # TODO: Handle dataset-specific multi-task stuff here, e.g., for OGB\n", - " pass\n", - "\n", - " # Sanity checking input args\n", - " if args.algorithm == 'groupDRO':\n", - " assert args.train_loader_kwargs['uniform_over_groups']\n", - " elif args.algorithm in ['deepCORAL', 'IRM']:\n", - " assert args.train_loader == 'group'\n", - " assert args.train_loader_kwargs['uniform_over_groups']\n", - " assert args.train_loader_kwargs['distinct_groups']\n", - "\n", - " # Other config\n", - " n_train_steps = len(train_loader) * args.n_epochs\n", - " prediction_fn = dataset_attributes[args.dataset]['prediction_fn']\n", - " loss = losses[args.loss_function]\n", - " metric_constructor = dataset_attributes[args.dataset]['metric']\n", - " train_g = train_grouper.metadata_to_group(train_dataset.metadata_array)\n", - " is_group_in_train = get_counts(train_g, train_grouper.n_groups) > 0\n", - " algorithm_constructor = algorithm_constructors[args.algorithm]\n", - " algorithm = algorithm_constructor(\n", - " args=args,\n", - " d_out=d_out,\n", - " grouper=train_grouper,\n", - " prediction_fn=prediction_fn,\n", - " loss=loss,\n", - " metric_constructor=metric_constructor,\n", - " n_train_steps=n_train_steps,\n", - " is_group_in_train=is_group_in_train)\n", - " return algorithm" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "parser = argparse.ArgumentParser()\n", - "\n", - "# Dataset\n", - "parser.add_argument('-d', '--dataset', choices=dataset_attributes.keys(), required=True)\n", - "parser.add_argument('--split_scheme', default='standard',\n", - " help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')\n", - "parser.add_argument('--root_dir', default=None, required=True,\n", - " help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')\n", - "parser.add_argument('--download', default=False, action='store_true',\n", - " help='If true, tries to downloads the dataset if it does not exist in root_dir.')\n", - "parser.add_argument('--frac', type=float, default=1.0,\n", - " help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes.')\n", - "\n", - "# Loaders\n", - "parser.add_argument('--train_loader', choices=['standard', 'group'], default='standard')\n", - "parser.add_argument('--train_loader_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--eval_loader', choices=['standard', 'group'], default='standard')\n", - "parser.add_argument('--batch_size', type=int, default=32)\n", - "\n", - "# Model\n", - "parser.add_argument(\n", - " '--model',\n", - " choices=model_attributes.keys(),\n", - " default='resnet50')\n", - "parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},\n", - " help='keyword arguments for model initialization passed as key1=value1 key2=value2')\n", - "parser.add_argument('--train_from_scratch', action='store_true', default=False)\n", - "\n", - "# Algorithm and objective\n", - "parser.add_argument('--algorithm', required=True, choices=algorithm_constructors.keys())\n", - "parser.add_argument('--algorithm_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--groupby_fields', nargs='+', default=None)\n", - "parser.add_argument('--loss_function', required=True, choices = losses.keys()) #TODO: make default\n", - "parser.add_argument('--val_metric', default=None)\n", - "\n", - "# Optimization\n", - "parser.add_argument('--n_epochs', type=int, default=4)\n", - "parser.add_argument('--optimizer', default=None, choices=optimizer_attributes.keys())\n", - "parser.add_argument('--lr', type=float, required=True)\n", - "parser.add_argument('--weight_decay', type=float, required=True)\n", - "parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--scheduler', default=None, choices=scheduler_attributes.keys())\n", - "parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})\n", - "parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')\n", - "parser.add_argument('--scheduler_metric_name')\n", - "\n", - "# Evaluation\n", - "parser.add_argument('--evaluate_all_splits', action='store_true', default=False)\n", - "parser.add_argument('--additional_eval_splits', nargs='+', default=[])\n", - "\n", - "# Misc\n", - "parser.add_argument('--device', default='cuda')\n", - "parser.add_argument('--seed', type=int, default=0)\n", - "parser.add_argument('--log_dir', default='./logs')\n", - "parser.add_argument('--log_every', default=50, type=int)\n", - "parser.add_argument('--save_step', type=int, default=None)\n", - "parser.add_argument('--save_best', action='store_true', default=False)\n", - "parser.add_argument('--save_last', action='store_true', default=False)\n", - "parser.add_argument('--save_outputs', action='store_true', default=False)\n", - "parser.add_argument('--no_group_logging', action='store_true', default=False)\n", - "\n", - "parser.add_argument('--resume', default=False, action='store_true')\n", - "\n", - "args = parser.parse_args()\n", - "\n", - "# Set defaults\n", - "if args.groupby_fields is None:\n", - " args.no_group_logging = True\n", - "if args.val_metric is None:\n", - " args.val_metric = dataset_attributes[args.dataset]['val_metric']\n", - "\n", - "## Initialize logs\n", - "if os.path.exists(args.log_dir) and args.resume:\n", - " resume=True\n", - " mode='a'\n", - "else:\n", - " resume=False\n", - " mode='w'\n", - "if not os.path.exists(args.log_dir):\n", - " os.makedirs(args.log_dir)\n", - "logger = Logger(os.path.join(args.log_dir, 'log.txt'), mode)\n", - "\n", - "# Record args\n", - "log_args(args, logger)\n", - "\n", - "# Set random seed\n", - "set_seed(args.seed)\n", - "\n", - "# Data\n", - "full_dataset = dataset_attributes[args.dataset]['constructor'](\n", - " root_dir=args.root_dir,\n", - " download=args.download,\n", - " split_scheme=args.split_scheme)\n", - "\n", - "# To implement data augmentation (i.e., have different transforms\n", - "# at training time vs. test time), modify these two lines:\n", - "train_transform = dataset_attributes[args.dataset]['transform'](args.model)\n", - "eval_transform = dataset_attributes[args.dataset]['transform'](args.model)\n", - "\n", - "train_grouper = CombinatorialGrouper(\n", - " dataset=full_dataset,\n", - " groupby_fields=args.groupby_fields)\n", - "\n", - "datasets = defaultdict(dict)\n", - "for split in full_dataset.split_dict.keys():\n", - " if split=='train':\n", - " transform = train_transform\n", - " verbose = True\n", - " elif split == 'val':\n", - " transform = eval_transform\n", - " verbose = True\n", - " else:\n", - " transform = eval_transform\n", - " verbose = False\n", - " # Get subset\n", - " datasets[split]['dataset'] = full_dataset.get_subset(\n", - " split,\n", - " frac=args.frac,\n", - " transform=transform)\n", - "\n", - " # Get loader\n", - " shared_loader_kwargs = {\n", - " 'num_workers': 4,\n", - " 'pin_memory': True,\n", - " 'batch_size': args.batch_size,\n", - " 'collate_fn': dataset_attributes[args.dataset]['collate']\n", - " }\n", - "\n", - " if split == 'train':\n", - " datasets[split]['loader'] = get_train_loader(\n", - " loader=args.train_loader,\n", - " dataset=datasets[split]['dataset'],\n", - " grouper=train_grouper,\n", - " train_loader_kwargs=args.train_loader_kwargs,\n", - " **shared_loader_kwargs)\n", - " else:\n", - " datasets[split]['loader'] = get_eval_loader(\n", - " loader=args.eval_loader,\n", - " dataset=datasets[split]['dataset'],\n", - " grouper=train_grouper,\n", - " **shared_loader_kwargs)\n", - "\n", - " # Set fields\n", - " datasets[split]['split'] = split\n", - " datasets[split]['name'] = full_dataset.split_names[split]\n", - " datasets[split]['verbose'] = verbose\n", - " # Loggers\n", - " datasets[split]['eval_logger'] = CSVBatchLogger(\n", - " os.path.join(args.log_dir, f'{split}_eval.csv'), mode=mode)\n", - " datasets[split]['algo_logger'] = CSVBatchLogger(\n", - " os.path.join(args.log_dir, f'{split}_algo.csv'), mode=mode)\n", - "\n", - "# Logging dataset info\n", - "if args.no_group_logging and full_dataset.is_classification and full_dataset.y_size==1:\n", - " log_grouper = CombinatorialGrouper(\n", - " dataset=full_dataset,\n", - " groupby_fields=['y'])\n", - "elif args.no_group_logging:\n", - " log_grouper = None\n", - "else:\n", - " log_grouper = train_grouper\n", - "log_group_data(args, datasets, log_grouper, logger)\n", - "\n", - "## Initialize algorithm\n", - "algorithm = initialize_algorithm(args, datasets, train_grouper)\n", - "\n", - "## Load saved results if resuming\n", - "if resume:\n", - " save_path = os.path.join(args.log_dir, 'last_model.pth')\n", - " prev_epoch, best_val_metric = load(algorithm, save_path)\n", - " epoch_offset = prev_epoch + 1\n", - "else:\n", - " epoch_offset=0\n", - " best_val_metric=None\n", - "\n", - "train(algorithm,\n", - " datasets,\n", - " logger,\n", - " args,\n", - " epoch_offset=epoch_offset,\n", - " best_val_metric=best_val_metric)\n", - "\n", - "logger.close()\n", - "for split in datasets:\n", - " datasets[split]['eval_logger'].close()\n", - " datasets[split]['algo_logger'].close()" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 2", - "language": "python", - "name": "python2" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 2 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.13" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py index bcf9ad7b..98fa2848 100644 --- a/wilds/common/metrics/all_metrics.py +++ b/wilds/common/metrics/all_metrics.py @@ -104,30 +104,45 @@ def _compute_group_wise(self, y_pred, y_true, g, n_groups): def worst(self, metrics): return minimum(metrics) -class Recall(Metric): - def __init__(self, prediction_fn=None, name=None, average='binary'): +# Break-even point of precision and recall. This is approx equal to average precision, and added due to currently open numerical issues with having zero true positives in a batch (see https://github.com/scikit-learn/scikit-learn/issues/8245 ). +def calc_PREven(y_true, y_score): + numpos = np.sum(y_true == 1) + top_ndces = np.argsort(y_score)[::-1] + ytr_ord = y_true[top_ndces] + m = {x: np.sum(ytr_ord[:numpos] == x) for x in np.unique(ytr_ord[:numpos])} + p = m[1] if 1 in m else 0 + if (len(m) == 0) or (0 not in m): + return None + else: + return 1.0*p/(m[0]+p) + +class MultiTaskPREven(MultiTaskMetric): + def __init__(self, prediction_fn=None, name=None): self.prediction_fn = prediction_fn if name is None: - name = f'recall' - if average is not None: - name+=f'-{average}' - self.average = average + name = f'preven' super().__init__(name=name) - def _compute(self, y_pred, y_true): + def _compute_flattened(self, flattened_y_pred, flattened_y_true): if self.prediction_fn is not None: - y_pred = self.prediction_fn(y_pred) - recall = sklearn.metrics.recall_score(y_true, y_pred, average=self.average, labels=torch.unique(y_true)) - return torch.tensor(recall) + flattened_y_pred = self.prediction_fn(flattened_y_pred) + ytr = np.array(flattened_y_true.squeeze().detach().cpu().numpy() > 0) + ypr = flattened_y_pred.squeeze().detach().cpu().numpy() + score = calc_PREven(ytr, ypr) + to_ret = torch.tensor(score).to(flattened_y_pred.device) + return to_ret + + def _compute(self, y_pred, y_true): + return self._compute_flattened(y_pred, y_true) def worst(self, metrics): return minimum(metrics) -class AveragePrecision(Metric): - def __init__(self, prediction_fn=None, name=None, average='macro'): +class Recall(Metric): + def __init__(self, prediction_fn=None, name=None, average='binary'): self.prediction_fn = prediction_fn if name is None: - name = f'avgprec' + name = f'recall' if average is not None: name+=f'-{average}' self.average = average @@ -136,17 +151,13 @@ def __init__(self, prediction_fn=None, name=None, average='macro'): def _compute(self, y_pred, y_true): if self.prediction_fn is not None: y_pred = self.prediction_fn(y_pred) - score = sklearn.metrics.average_precision_score( - np.array(y_true.squeeze().detach().cpu().numpy() > 0), - y_pred.squeeze().detach().cpu().numpy(), - average=self.average - ) - return torch.tensor(score) + recall = sklearn.metrics.recall_score(y_true, y_pred, average=self.average, labels=torch.unique(y_true)) + return torch.tensor(recall) def worst(self, metrics): return minimum(metrics) -class MTAveragePrecision(Metric): +class AveragePrecision(Metric): def __init__(self, prediction_fn=None, name=None, average='macro'): self.prediction_fn = prediction_fn if name is None: @@ -159,15 +170,12 @@ def __init__(self, prediction_fn=None, name=None, average='macro'): def _compute(self, y_pred, y_true): if self.prediction_fn is not None: y_pred = self.prediction_fn(y_pred) - ytr = np.array(torch.flatten(y_true.squeeze()).detach().cpu().numpy() > 0) - ypr = torch.flatten(y_pred.squeeze()).detach().cpu().numpy() score = sklearn.metrics.average_precision_score( - ytr, - ypr, + np.array(y_true.squeeze().detach().cpu().numpy() > 0), + y_pred.squeeze().detach().cpu().numpy(), average=self.average ) - to_ret = torch.tensor(score).to(y_pred.device) - return to_ret + return torch.tensor(score) def worst(self, metrics): return minimum(metrics) From 5429bca69dce2e51cccad2115cbc3c046d77822c Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Thu, 25 Mar 2021 11:06:15 -0700 Subject: [PATCH 109/244] no-op --- wilds/datasets/encodetfbs_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index aca43adb..26c0b4cb 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -131,7 +131,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' train_msk = (self._split_array == self._split_dict['train']) allzeroes_msk = (self._y_array.sum(axis=1) == 0).numpy() indices_to_keep = indices_to_keep & ~(train_msk & allzeroes_msk) - + self._metadata_df = self._metadata_df[indices_to_keep] self._split_array = self._split_array[indices_to_keep] self._y_array = self._y_array[indices_to_keep] From 677bc0055e8c28ee8701454d85d446ceb1a87727 Mon Sep 17 00:00:00 2001 From: Henrik Marklund Date: Thu, 25 Mar 2021 22:34:55 -0700 Subject: [PATCH 110/244] update compressed size of iwildcam --- wilds/datasets/iwildcam_dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/wilds/datasets/iwildcam_dataset.py b/wilds/datasets/iwildcam_dataset.py index 533f7fbb..d98bbb72 100644 --- a/wilds/datasets/iwildcam_dataset.py +++ b/wilds/datasets/iwildcam_dataset.py @@ -40,7 +40,8 @@ class IWildCamDataset(WILDSDataset): _versions_dict = { '2.0': { 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x6313da2b204647e79a14b468131fcd64/contents/blob/', - 'compressed_size': 12_000_000_000}} + 'compressed_size': 11_957_420_032}} + def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): From a4ed704ad53d58c73b09fe6334b666608e83a06f Mon Sep 17 00:00:00 2001 From: Berton Earnshaw Date: Sun, 28 Mar 2021 23:21:14 -0600 Subject: [PATCH 111/244] Add rxrx1 dataset --- examples/configs/datasets.py | 20 ++++ examples/transforms.py | 39 ++++++++ wilds/__init__.py | 1 + wilds/datasets/rxrx1_dataset.py | 158 ++++++++++++++++++++++++++++++++ wilds/get_dataset.py | 4 + 5 files changed, 222 insertions(+) create mode 100644 wilds/datasets/rxrx1_dataset.py diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index cd2d1d6f..c3ab5253 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -282,6 +282,26 @@ 'n_epochs': 4, 'process_outputs_function': None, }, + 'rxrx1': { + 'split_scheme': 'official', + 'model': 'resnet50', + 'model_kwargs': {'pretrained': True}, + 'train_transform': None, + 'eval_transform': None, + 'loss_function': 'cross_entropy', + 'groupby_fields': ['experiment'], # TODO what is this? + 'val_metric': 'accuracy', + 'val_metric_decreasing': False, + 'algo_log_metric': 'accuracy', + 'optimizer': 'Adam', + 'optimizer_kwargs': {}, + 'scheduler': None, # TODO cosine with warmup from transformers + 'batch_size': 1400, + 'lr': 1e-3, + 'weight_decay': 1e-5, + 'n_epochs': 60, + 'process_outputs_function': None, + }, } ########################################## diff --git a/examples/transforms.py b/examples/transforms.py index bafbd42f..2b14eedf 100644 --- a/examples/transforms.py +++ b/examples/transforms.py @@ -1,4 +1,7 @@ +import random + import torchvision.transforms as transforms +import torchvision.transforms.functional as TF from transformers import BertTokenizerFast, DistilBertTokenizerFast import torch @@ -13,6 +16,8 @@ def initialize_transform(transform_name, config, dataset): return initialize_image_resize_and_center_crop_transform(config, dataset) elif transform_name=='poverty_train': return initialize_poverty_train_transform() + elif transform_name=='rxrx1': + return initialize_rxrx1_transform(dataset) else: raise ValueError(f"{transform_name} not recognized") @@ -101,3 +106,37 @@ def transform_rgb(img): return img transform = transforms.Lambda(lambda x: transform_rgb(x)) return transform + + +def initialize_rxrx1_transform(dataset: str): + + def standardize(x: torch.Tensor) -> torch.Tensor: + mean = x.mean(dim=(1, 2)) + std = x.std(dim=(1, 2)) + std[std == 0.] = 1. + return TF.normalize(x, mean, std) + t_standardize = transforms.Lambda(lambda x: standardize(x)) + + def random_d8(x: torch.Tensor) -> torch.Tensor: + angle = random.choice([0, 90, 180, 270]) + if angle > 0: + x = TF.rotate(x, angle) + if random.random() < 0.5: + x = TF.hflip(x) + return x + t_random_d8 = transforms.Lambda(lambda x: random_d8(x)) + + if dataset == 'train': + transforms_ls = [ + t_random_d8, + transforms.ToTensor(), + t_standardize, + ] + elif dataset == 'test': + transforms_ls = [ + transforms.ToTensor(), + t_standardize, + ] + transform = transforms.Compose(transforms_ls) + + return transform diff --git a/wilds/__init__.py b/wilds/__init__.py index 77f0ad5a..cf76d81f 100644 --- a/wilds/__init__.py +++ b/wilds/__init__.py @@ -18,6 +18,7 @@ 'yelp', 'bdd100k', 'sqf', + 'rxrx1', ] supported_datasets = benchmark_datasets + additional_datasets diff --git a/wilds/datasets/rxrx1_dataset.py b/wilds/datasets/rxrx1_dataset.py new file mode 100644 index 00000000..fa1466a7 --- /dev/null +++ b/wilds/datasets/rxrx1_dataset.py @@ -0,0 +1,158 @@ +from datetime import datetime +import os +from pathlib import Path + +from PIL import Image +import pandas as pd +import numpy as np +import torch + +from wilds.datasets.wilds_dataset import WILDSDataset +from wilds.common.grouper import CombinatorialGrouper +from wilds.common.metrics.all_metrics import Accuracy, Recall, F1 + + +class RxRx1Dataset(WILDSDataset): + """ + The RxRx1 Dataset. + This is a modified version of the original RxRx1 dataset. + + Input (x): + 3-channel fluorescent microscopy images of cells + + Label (y): + y is one of 1,139 classes: + - 0 to 1107: treatment siRNAs + - 1108 to 1137: positive control siRNAs + - 1138: negative control siRNA + + Metadata: + Each image is annotated with its experiment, plate, well, and site, as + well as with the id of the siRNA the cells were perturbed with. + + Website: + https://www.rxrx.ai/rxrx1 + https://www.kaggle.com/c/recursion-cellular-image-classification + + Original publication: + @article{, + title={}, + author={}, + journal={}, + year={} + } + + License: + This work is licensed under a Creative Commons + Attribution-NonCommercial-ShareAlike 4.0 International License. To view + a copy of this license, visit + http://creativecommons.org/licenses/by-nc-sa/4.0/ or send a letter to + Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + """ + _dataset_name = 'rxrx1' + _versions_dict = { + '1.0': { + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0xc01e117bb4504f988700408eaeeb16a8/contents/blob/', + 'compressed_size': 7_413_123_845} + } + + def __init__(self, version=None, root_dir='rxrx1-wilds', download=False, + split_scheme='official'): + + self._version = version + self._split_scheme = split_scheme + if self._split_scheme != 'official': + raise ValueError(f'Split scheme {self._split_scheme} not recognized') + + # path + self._data_dir = Path(self.initialize_data_dir(root_dir, download)) + + # Load splits + df = pd.read_csv(self._data_dir / 'metadata.csv') + + # Splits + split_dict = {'train': 0, 'test': 1} + self._split_array = df.dataset.apply(split_dict.get).values + + # Filenames + def create_filepath(row): + filepath = os.path.join(row.experiment, + f'Plate{row.plate}', + f'{row.well}_s{row.site}.png') + return filepath + self._input_array = df.apply(create_filepath, axis=1).values + + # Labels + self._y_array = torch.tensor(df['sirna_id'].values) + self._n_classes = max(df['sirna_id']) + 1 + self._y_size = 1 + assert len(np.unique(df['sirna_id'])) == self._n_classes + + # Location/group info + # FIXME need to enumerate experiments + # n_groups = max(df['location_remapped']) + 1 + # self._n_groups = n_groups + # assert len(np.unique(df['location_remapped'])) == self._n_groups + + # FIXME experiment and well are strings + self._metadata_array = torch.tensor( + np.stack([df['experiment'].values, + df['plate'].values, + df['well'].values, + df['site'].values, + self.y_array], axis=1) + ) + self._metadata_fields = ['experiment', 'plate', 'well', 'site', 'y'] + + # eval grouper + self._eval_grouper = CombinatorialGrouper( + dataset=self, + groupby_fields=(['experiment']) + ) + + super().__init__(root_dir, download, split_scheme) + + def eval(self, y_pred, y_true, metadata, prediction_fn=None): + """ + Computes all evaluation metrics. + Args: + - y_pred (Tensor): Predictions from a model. By default, they are + predicted labels (LongTensor). But they can also be other model + outputs such that prediction_fn(y_pred) are predicted labels. + - y_true (LongTensor): Ground-truth labels + - metadata (Tensor): Metadata + - prediction_fn (function): A function that turns y_pred into predicted labels + Output: + - results (dictionary): Dictionary of evaluation metrics + - results_str (str): String summarizing the evaluation metrics + """ + metrics = [ + Accuracy(prediction_fn=prediction_fn), + ] + + results = {} + + for i in range(len(metrics)): + results.update({ + **metrics[i].compute(y_pred, y_true), + }) + + results_str = ( + f"Average acc: {results[metrics[0].agg_metric_field]:.3f}\n" + ) + + return results, results_str + + def get_input(self, idx): + """ + Args: + - idx (int): Index of a data point + Output: + - x (Tensor): Input features of the idx-th data point + """ + + # All images are in the train folder + img_path = self.data_dir / self._input_array[idx] + img = Image.open(img_path) + + return img diff --git a/wilds/get_dataset.py b/wilds/get_dataset.py index 1073100f..2c5e2b19 100644 --- a/wilds/get_dataset.py +++ b/wilds/get_dataset.py @@ -77,3 +77,7 @@ def get_dataset(dataset, version=None, **dataset_kwargs): elif dataset == 'sqf': from wilds.datasets.sqf_dataset import SQFDataset return SQFDataset(version=version, **dataset_kwargs) + + elif dataset == 'rxrx1': + from wilds.datasets.rxrx1_dataset import RxRx1Dataset + return RxRx1Dataset(version=version, **dataset_kwargs) From a8b714344840bb942f2f4c84d3d2e661f7ad5eb3 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Mon, 29 Mar 2021 10:54:44 -0700 Subject: [PATCH 112/244] Metadata, filepath, and batch size fixes --- examples/configs/datasets.py | 8 ++++---- wilds/datasets/rxrx1_dataset.py | 27 +++++++++++++++++---------- 2 files changed, 21 insertions(+), 14 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index c3ab5253..af9cfa8f 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -286,17 +286,17 @@ 'split_scheme': 'official', 'model': 'resnet50', 'model_kwargs': {'pretrained': True}, - 'train_transform': None, - 'eval_transform': None, + 'train_transform': 'image_base', + 'eval_transform': 'image_base', 'loss_function': 'cross_entropy', - 'groupby_fields': ['experiment'], # TODO what is this? + 'groupby_fields': ['experiment'], 'val_metric': 'accuracy', 'val_metric_decreasing': False, 'algo_log_metric': 'accuracy', 'optimizer': 'Adam', 'optimizer_kwargs': {}, 'scheduler': None, # TODO cosine with warmup from transformers - 'batch_size': 1400, + 'batch_size': 32, #1400, 'lr': 1e-3, 'weight_decay': 1e-5, 'n_epochs': 60, diff --git a/wilds/datasets/rxrx1_dataset.py b/wilds/datasets/rxrx1_dataset.py index fa1466a7..e5e04421 100644 --- a/wilds/datasets/rxrx1_dataset.py +++ b/wilds/datasets/rxrx1_dataset.py @@ -34,6 +34,7 @@ class RxRx1Dataset(WILDSDataset): https://www.rxrx.ai/rxrx1 https://www.kaggle.com/c/recursion-cellular-image-classification + FIXME Original publication: @article{, title={}, @@ -71,12 +72,15 @@ def __init__(self, version=None, root_dir='rxrx1-wilds', download=False, df = pd.read_csv(self._data_dir / 'metadata.csv') # Splits - split_dict = {'train': 0, 'test': 1} + # FIXME: Add validation + self._split_dict = {'train': 0, 'test': 1} + self._split_names = {'train': 'Train', 'test': 'Test'} self._split_array = df.dataset.apply(split_dict.get).values # Filenames def create_filepath(row): - filepath = os.path.join(row.experiment, + filepath = os.path.join('images', + row.experiment, f'Plate{row.plate}', f'{row.well}_s{row.site}.png') return filepath @@ -88,17 +92,20 @@ def create_filepath(row): self._y_size = 1 assert len(np.unique(df['sirna_id'])) == self._n_classes - # Location/group info - # FIXME need to enumerate experiments - # n_groups = max(df['location_remapped']) + 1 - # self._n_groups = n_groups - # assert len(np.unique(df['location_remapped'])) == self._n_groups + # Convert experiment and well from strings to idxs + indexed_metadata = {} + self._metadata_map = {} + for key in ['experiment', 'well']: + all_values = list(df[key].unique()) + value_to_idx_map = {value: idx for idx, value in enumerate(all_values)} + value_idxs = [value_to_idx_map[value] for value in df[key].tolist()] + self._metadata_map[key] = all_values + indexed_metadata[key] = value_idxs - # FIXME experiment and well are strings self._metadata_array = torch.tensor( - np.stack([df['experiment'].values, + np.stack([indexed_metadata['experiment'], df['plate'].values, - df['well'].values, + indexed_metadata['well'], df['site'].values, self.y_array], axis=1) ) From c9d7f1d2f3036d2a82124b81bd69204fa492969a Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Mon, 29 Mar 2021 11:17:26 -0700 Subject: [PATCH 113/244] Fix eval --- examples/configs/datasets.py | 4 ++-- wilds/datasets/rxrx1_dataset.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index af9cfa8f..4f5f74b7 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -290,7 +290,7 @@ 'eval_transform': 'image_base', 'loss_function': 'cross_entropy', 'groupby_fields': ['experiment'], - 'val_metric': 'accuracy', + 'val_metric': 'acc_avg', 'val_metric_decreasing': False, 'algo_log_metric': 'accuracy', 'optimizer': 'Adam', @@ -300,7 +300,7 @@ 'lr': 1e-3, 'weight_decay': 1e-5, 'n_epochs': 60, - 'process_outputs_function': None, + 'process_outputs_function': 'multiclass_logits_to_pred', }, } diff --git a/wilds/datasets/rxrx1_dataset.py b/wilds/datasets/rxrx1_dataset.py index e5e04421..6558d8e7 100644 --- a/wilds/datasets/rxrx1_dataset.py +++ b/wilds/datasets/rxrx1_dataset.py @@ -75,7 +75,9 @@ def __init__(self, version=None, root_dir='rxrx1-wilds', download=False, # FIXME: Add validation self._split_dict = {'train': 0, 'test': 1} self._split_names = {'train': 'Train', 'test': 'Test'} - self._split_array = df.dataset.apply(split_dict.get).values + self._split_array = df.dataset.apply(self._split_dict.get).values + # split_dict = {'train': 0, 'test': 1} + # self._split_array = df.dataset.apply(split_dict.get).values # Filenames def create_filepath(row): From 3dde9f9be15f44a90db84166fbc80db3455a7b35 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 30 Mar 2021 15:19:21 +0200 Subject: [PATCH 114/244] add initial faster implementation --- .gitignore | 3 +- .vscode/settings.json | 3 + examples/algorithms/single_model_algorithm.py | 2 +- examples/configs/datasets.py | 7 +- examples/configs/model.py | 6 + examples/losses.py | 10 + examples/models/detection/fasterrcnn.py | 510 ++++++++++++++++++ examples/models/initializer.py | 19 + examples/utils.py | 2 + wilds/common/metrics/all_metrics.py | 8 +- wilds/common/utils.py | 2 + wilds/datasets/gwhd_dataset.py | 5 +- 12 files changed, 564 insertions(+), 13 deletions(-) create mode 100644 .vscode/settings.json create mode 100644 examples/models/detection/fasterrcnn.py diff --git a/.gitignore b/.gitignore index e1076393..153471f1 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,5 @@ build dist wilds.egg-info data -logs \ No newline at end of file +logs +test_faster \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..d0c7592c --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python.pythonPath": "/home/gdaubige/anaconda3/envs/wilds/bin/python" +} \ No newline at end of file diff --git a/examples/algorithms/single_model_algorithm.py b/examples/algorithms/single_model_algorithm.py index c5d1071b..94767742 100644 --- a/examples/algorithms/single_model_algorithm.py +++ b/examples/algorithms/single_model_algorithm.py @@ -51,7 +51,7 @@ def process_batch(self, batch): x = move_to(x, self.device) y_true = move_to(y_true, self.device) g = move_to(self.grouper.metadata_to_group(metadata), self.device) - outputs = self.model(x) + outputs = self.model(x, y_true) results = { 'g': g, diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index a8102723..78bc9cb2 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -284,15 +284,13 @@ }, 'gwhd': { 'split_scheme': 'official', - 'model': 'detr', + 'model': 'fasterrcnn', 'train_transform': 'image_base', 'eval_transform': 'image_base', 'model_kwargs': { - 'aux_loss': True, - 'n_queries': 200, 'n_classes': 1, 'pretrained': True}, - 'loss_function': 'detr_set_criterion', + 'loss_function': 'faster_criterion', 'groupby_fields': ['location'], 'val_metric': 'detection_accuracy_avg', # TODO 'val_metric_decreasing': False, @@ -304,7 +302,6 @@ 'lr': 1e-5, 'weight_decay': 1e-4, 'n_epochs': 50, - 'process_outputs_function': 'remove_detr_aux_outputs', 'loader_kwargs': { 'num_workers': 1, 'pin_memory': True, diff --git a/examples/configs/model.py b/examples/configs/model.py index 79b29841..31d755ef 100644 --- a/examples/configs/model.py +++ b/examples/configs/model.py @@ -66,5 +66,11 @@ # 'eos_coef': 0.1, 'eos_coef': 0.5, } + }, + 'fasterrcnn': { + 'model_kwargs': { + # Backbone. Always uses sine position embedding. + 'pretrained': True, + } } } diff --git a/examples/losses.py b/examples/losses.py index 3e5cd933..7fb1ae92 100644 --- a/examples/losses.py +++ b/examples/losses.py @@ -16,10 +16,20 @@ def initialize_loss(config, d_out): elif config.loss_function == 'detr_set_criterion': return ElementwiseLoss(loss_fn=get_detr_set_criterion(config, d_out)) + elif config.loss_function == 'faster_criterion': + return ElementwiseLoss(loss_fn=get_faster_criterion(config)) else: raise ValueError(f'config.loss_function {config.loss_function} not recognized') + +def get_faster_criterion(config): + from examples.models.detection.fasterrcnn import FasterRCNNLoss + + criterion = FasterRCNNLoss(config.device) + return criterion + + def get_detr_set_criterion(config, d_out): from examples.models.detr.matcher import HungarianMatcher from examples.models.detr.detr import SetCriterion diff --git a/examples/models/detection/fasterrcnn.py b/examples/models/detection/fasterrcnn.py new file mode 100644 index 00000000..a5aea7a7 --- /dev/null +++ b/examples/models/detection/fasterrcnn.py @@ -0,0 +1,510 @@ +import torch +import torch.nn as nn +import torchvision +from collections import OrderedDict +import torch +from torch import nn, Tensor +import warnings +from typing import Tuple, List, Dict, Optional, Union + +from torch import nn + + +import torchvision +from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, FasterRCNN +from torchvision.models.detection.backbone_utils import resnet_fpn_backbone +from torchvision.models.utils import load_state_dict_from_url + + +from torchvision.ops import misc as misc_nn_ops +from torchvision.ops import MultiScaleRoIAlign + + +from torchvision.models.detection.anchor_utils import AnchorGenerator +from torchvision.models.detection.generalized_rcnn import GeneralizedRCNN +from torchvision.models.detection.faster_rcnn import TwoMLPHead + +from torchvision.models.detection.rpn import RPNHead, RegionProposalNetwork, concat_box_prediction_layers,permute_and_flatten +from torchvision.models.detection.roi_heads import RoIHeads +from torchvision.models.detection.transform import GeneralizedRCNNTransform + +from torchvision.models.detection import _utils as det_utils +from torch.nn import functional as F + + +model_urls = { + 'fasterrcnn_resnet50_fpn_coco': + 'https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth', + 'fasterrcnn_mobilenet_v3_large_320_fpn_coco': + 'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth', + 'fasterrcnn_mobilenet_v3_large_fpn_coco': + 'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth' +} + +def batch_concat_box_prediction_layers(box_cls, box_regression): + # type: (List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor] + box_cls_flattened = [] + box_regression_flattened = [] + # for each feature level, permute the outputs to make them be in the + # same format as the labels. Note that the labels are computed for + # all feature levels concatenated, so we keep the same representation + # for the objectness and the box_regression + for box_cls_per_level, box_regression_per_level in zip( + box_cls, box_regression + ): + N, AxC, H, W = box_cls_per_level.shape + Ax4 = box_regression_per_level.shape[1] + A = Ax4 // 4 + C = AxC // A + box_cls_per_level = permute_and_flatten( + box_cls_per_level, N, A, C, H, W + ) + box_cls_flattened.append(box_cls_per_level) + + box_regression_per_level = permute_and_flatten( + box_regression_per_level, N, A, 4, H, W + ) + box_regression_flattened.append(box_regression_per_level) + # concatenate on the first dimension (representing the feature levels), to + # take into account the way the labels were generated (with all feature maps + # being concatenated as well) + + batch_size = box_regression_flattened[0].shape[0] + + new_box_cls = [] + new_box_regression = [] + for batch_idx in range(batch_size): + element_box_cls = [torch.unsqueeze(item[batch_idx],dim=0) for item in box_cls_flattened] + element_box_regression = [torch.unsqueeze(item[batch_idx],dim=0) for item in box_regression_flattened] + + element_box_cls = torch.cat(element_box_cls, dim=1).flatten(0, -2) + element_box_regression = torch.cat(element_box_regression, dim=1).reshape(-1, 4) + new_box_cls.append(element_box_cls) + new_box_regression.append(element_box_regression) + + + return new_box_cls, new_box_regression + +class RegionProposalNetworkWILDS(RegionProposalNetwork): + def __init__(self, + anchor_generator, + head, + # + fg_iou_thresh, bg_iou_thresh, + batch_size_per_image, positive_fraction, + # + pre_nms_top_n, post_nms_top_n, nms_thresh): + super().__init__(anchor_generator, + head, + fg_iou_thresh, bg_iou_thresh, + batch_size_per_image, positive_fraction, + pre_nms_top_n, post_nms_top_n, nms_thresh) + + def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets): + # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor] + """ + Arguments: + objectness (Tensor) + pred_bbox_deltas (Tensor) + labels (List[Tensor]) + regression_targets (List[Tensor]) + Returns: + objectness_loss (Tensor) + box_loss (Tensor) + """ + + + + + objectness, pred_bbox_deltas = batch_concat_box_prediction_layers(objectness, pred_bbox_deltas) + + objectness_loss = [] + box_loss = [] + + for objectness_, regression_targets_,labels_,objectness_,pred_bbox_deltas_ in zip(objectness,regression_targets,labels,objectness,pred_bbox_deltas): + + sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(torch.unsqueeze(labels_,dim=0)) + + sampled_pos_inds = torch.where(torch.cat(sampled_pos_inds, dim=0))[0] + sampled_neg_inds = torch.where(torch.cat(sampled_neg_inds, dim=0))[0] + sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0) + + + box_loss.append(det_utils.smooth_l1_loss( + pred_bbox_deltas_[sampled_pos_inds], + regression_targets_[sampled_pos_inds], + beta=1 / 9, + size_average=False, + ) / (sampled_inds.numel())) + + + objectness_loss.append(F.binary_cross_entropy_with_logits( + objectness_[sampled_inds].flatten(), labels_[sampled_inds] + )) + + return torch.stack(objectness_loss), torch.stack(box_loss) + def forward(self, + images, # type: ImageList + features, # type: Dict[str, Tensor] + targets=None # type: Optional[List[Dict[str, Tensor]]] + ): + # type: (...) -> Tuple[List[Tensor], Dict[str, Tensor]] + """ + Arguments: + images (ImageList): images for which we want to compute the predictions + features (OrderedDict[Tensor]): features computed from the images that are + used for computing the predictions. Each tensor in the list + correspond to different feature levels + targets (List[Dict[Tensor]]): ground-truth boxes present in the image (optional). + If provided, each element in the dict should contain a field `boxes`, + with the locations of the ground-truth boxes. + Returns: + boxes (List[Tensor]): the predicted boxes from the RPN, one Tensor per + image. + losses (Dict[Tensor]): the losses for the model during training. During + testing, it is an empty dict. + """ + # RPN uses all feature maps that are available + features = list(features.values()) + objectness, pred_bbox_deltas = self.head(features) + anchors = self.anchor_generator(images, features) + + num_images = len(anchors) + num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness] + num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors] + + raw_objectness = objectness + raw_pred_bbox_deltas = pred_bbox_deltas + objectness, pred_bbox_deltas = \ + concat_box_prediction_layers(objectness, pred_bbox_deltas) + # apply pred_bbox_deltas to anchors to obtain the decoded proposals + # note that we detach the deltas because Faster R-CNN do not backprop through + # the proposals + proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors) + proposals = proposals.view(num_images, -1, 4) + boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level) + + losses = {} + assert targets is not None + labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets) + regression_targets = self.box_coder.encode(matched_gt_boxes, anchors) + loss_objectness, loss_rpn_box_reg = self.compute_loss( + raw_objectness, raw_pred_bbox_deltas, labels, regression_targets) + + losses = { + "loss_objectness": loss_objectness, + "loss_rpn_box_reg": loss_rpn_box_reg, + } + return boxes, losses + +def fastrcnn_loss(class_logits, box_regression, labels, regression_targets): + # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor] + """ + Computes the loss for Faster R-CNN. + Arguments: + class_logits (Tensor) + box_regression (Tensor) + labels (list[BoxList]) + regression_targets (Tensor) + Returns: + classification_loss (Tensor) + box_loss (Tensor) + """ + + + class_logits = torch.split(class_logits, 512,dim=0) + box_regression = torch.split(box_regression, 512,dim=0) + + classification_loss = [] + box_loss = [] + + for class_logits_, box_regression_, labels_, regression_targets_ in zip(class_logits, box_regression, labels, regression_targets): + + classification_loss.append(F.cross_entropy(class_logits_, labels_)) + # get indices that correspond to the regression targets for + # the corresponding ground truth labels, to be used with + # advanced indexing + sampled_pos_inds_subset = torch.where(labels_ > 0)[0] + + labels_pos = labels_[sampled_pos_inds_subset] + N, num_classes = class_logits_.shape + + + box_regression_ = box_regression_.reshape(N, -1, 4) + + box_loss_ = det_utils.smooth_l1_loss( + box_regression_[sampled_pos_inds_subset, labels_pos], + regression_targets_[sampled_pos_inds_subset], + beta=1 / 9, + size_average=False, + ) + box_loss.append(box_loss_ / labels_.numel()) + + return torch.stack(classification_loss), torch.stack(box_loss) + +class RoIHeadsWILDS(RoIHeads): + def __init__(self, box_roi_pool, box_head, box_predictor, box_fg_iou_thresh, box_bg_iou_thresh,box_batch_size_per_image,box_positive_fraction,bbox_reg_weights,box_score_thresh,box_nms_thresh,box_detections_per_img): + + + super().__init__(box_roi_pool, box_head, box_predictor, + box_fg_iou_thresh, box_bg_iou_thresh, + box_batch_size_per_image, box_positive_fraction, + bbox_reg_weights, + box_score_thresh, box_nms_thresh, box_detections_per_img) + + def forward(self, + features, # type: Dict[str, Tensor] + proposals, # type: List[Tensor] + image_shapes, # type: List[Tuple[int, int]] + targets=None # type: Optional[List[Dict[str, Tensor]]] + ): + # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]] + """ + Arguments: + features (List[Tensor]) + proposals (List[Tensor[N, 4]]) + image_shapes (List[Tuple[H, W]]) + targets (List[Dict]) + """ + if targets is not None: + for t in targets: + # TODO: https://github.com/pytorch/pytorch/issues/26731 + floating_point_types = (torch.float, torch.double, torch.half) + assert t["boxes"].dtype in floating_point_types, 'target boxes must of float type' + assert t["labels"].dtype == torch.int64, 'target labels must of int64 type' + if self.has_keypoint(): + assert t["keypoints"].dtype == torch.float32, 'target keypoints must of float type' + + + # here batch is maintained + proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets) + + + box_features = self.box_roi_pool(features, proposals, image_shapes) # batch is maintained + box_features = self.box_head(box_features) + + class_logits, box_regression = self.box_predictor(box_features) + + result = torch.jit.annotate(List[Dict[str, torch.Tensor]], []) + losses = {} + assert labels is not None and regression_targets is not None + loss_classifier, loss_box_reg = fastrcnn_loss( + class_logits, box_regression, labels, regression_targets) + losses = { + "loss_classifier": loss_classifier, + "loss_box_reg": loss_box_reg + } + boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes) + num_images = len(boxes) + for i in range(num_images): + result.append( + { + "boxes": boxes[i], + "labels": labels[i], + "scores": scores[i], + } + ) + + + return result, losses + +def fasterrcnn_resnet50_fpn(pretrained=False, progress=True, + num_classes=91, pretrained_backbone=True, trainable_backbone_layers=3, **kwargs): + + assert trainable_backbone_layers <= 5 and trainable_backbone_layers >= 0 + # dont freeze any layers if pretrained model or backbone is not used + if not (pretrained or pretrained_backbone): + trainable_backbone_layers = 5 + if pretrained: + # no need to download the backbone if pretrained is set + pretrained_backbone = False + backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, trainable_layers=trainable_backbone_layers) + model = FastWILDS(backbone, 91, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls['fasterrcnn_resnet50_fpn_coco'], + progress=progress) + model.load_state_dict(state_dict) + + + # get number of input features for the classifier + in_features = model.roi_heads.box_predictor.cls_score.in_features + # replace the pre-trained head with a new one + model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes+1) + + + return model + + + +class FastWILDS(GeneralizedRCNN): + def __init__(self, backbone, num_classes=None, + # transform parameters + min_size=800, max_size=1333, + image_mean=None, image_std=None, + # RPN parameters + rpn_anchor_generator=None, rpn_head=None, + rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000, + rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000, + rpn_nms_thresh=0.7, + rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3, + rpn_batch_size_per_image=256, rpn_positive_fraction=0.5, + # Box parameters + box_roi_pool=None, box_head=None, box_predictor=None, + box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100, + box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5, + box_batch_size_per_image=512, box_positive_fraction=0.25, + bbox_reg_weights=None): + + if not hasattr(backbone, "out_channels"): + raise ValueError( + "backbone should contain an attribute out_channels " + "specifying the number of output channels (assumed to be the " + "same for all the levels)") + + assert isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))) + assert isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))) + + if num_classes is not None: + if box_predictor is not None: + raise ValueError("num_classes should be None when box_predictor is specified") + else: + if box_predictor is None: + raise ValueError("num_classes should not be None when box_predictor " + "is not specified") + + out_channels = backbone.out_channels + + if rpn_anchor_generator is None: + anchor_sizes = ((32,), (64,), (128,), (256,), (512,)) + aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) + rpn_anchor_generator = AnchorGenerator( + anchor_sizes, aspect_ratios + ) + if rpn_head is None: + rpn_head = RPNHead( + out_channels, rpn_anchor_generator.num_anchors_per_location()[0] + ) + + rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test) + rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test) + + rpn = RegionProposalNetworkWILDS( + rpn_anchor_generator, rpn_head, + rpn_fg_iou_thresh, rpn_bg_iou_thresh, + rpn_batch_size_per_image, rpn_positive_fraction, + rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh) + + if box_roi_pool is None: + box_roi_pool = MultiScaleRoIAlign( + featmap_names=['0', '1', '2', '3'], + output_size=7, + sampling_ratio=2) + + if box_head is None: + resolution = box_roi_pool.output_size[0] + representation_size = 1024 + box_head = TwoMLPHead( + out_channels * resolution ** 2, + representation_size) + + if box_predictor is None: + representation_size = 1024 + box_predictor = FastRCNNPredictor( + representation_size, + num_classes) + + roi_heads = RoIHeadsWILDS( + box_roi_pool, box_head, box_predictor, + box_fg_iou_thresh, box_bg_iou_thresh, + box_batch_size_per_image, box_positive_fraction, + bbox_reg_weights, + box_score_thresh, box_nms_thresh, box_detections_per_img) + + if image_mean is None: + image_mean = [0.485, 0.456, 0.406] + if image_std is None: + image_std = [0.229, 0.224, 0.225] + transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std) + + super(FastWILDS, self).__init__(backbone, rpn, roi_heads, transform) + # Set your own forward pass + def forward(self, images, targets=None): + + if targets is None: + raise ValueError("In training mode, targets should be passed") + assert targets is not None + for target in targets: + boxes = target["boxes"] + if isinstance(boxes, torch.Tensor): + if len(boxes.shape) != 2 or boxes.shape[-1] != 4: + raise ValueError("Expected target boxes to be a tensor" + "of shape [N, 4], got {:}.".format( + boxes.shape)) + else: + raise ValueError("Expected target boxes to be of type " + "Tensor, got {:}.".format(type(boxes))) + + original_image_sizes: List[Tuple[int, int]] = [] + for img in images: + val = img.shape[-2:] + assert len(val) == 2 + original_image_sizes.append((val[0], val[1])) + + images, targets = self.transform(images, targets) + + # Check for degenerate boxes + # TODO: Move this to a function + if targets is not None: + for target_idx, target in enumerate(targets): + boxes = target["boxes"] + degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] + if degenerate_boxes.any(): + # print the first degenerate box + bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0] + degen_bb: List[float] = boxes[bb_idx].tolist() + raise ValueError("All bounding boxes should have positive height and width." + " Found invalid box {} for target at index {}." + .format(degen_bb, target_idx)) + + features = self.backbone(images.tensors) + if isinstance(features, torch.Tensor): + features = OrderedDict([('0', features)]) + + proposals, proposal_losses = self.rpn(images, features, targets) + + + detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets) + + + detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) + + for idx, det in enumerate(detections): + det["losses"] = {} + for k,v in proposal_losses.items(): + det["losses"][k] = v[idx] + for k,v in detector_losses.items(): + det["losses"][k] = v[idx] + return detections + + + + + + +class FasterRCNNLoss(nn.Module): + def __init__(self,device): + self.device = device + super().__init__() + + def forward(self, outputs, targets): + + + # loss values are loss_classifier loss_box_reg loss_objectness": loss_objectness, loss_rpn_box_reg + + elementwise_loss = torch.stack([sum(v for v in item["losses"].values()) for item in outputs]) + + + + return elementwise_loss + + diff --git a/examples/models/initializer.py b/examples/models/initializer.py index 0c6387a8..ab9942e6 100644 --- a/examples/models/initializer.py +++ b/examples/models/initializer.py @@ -82,6 +82,11 @@ def initialize_model(config, d_out, is_featurizer=False): raise NotImplementedError('Featurizer not implemented for detection yet') else: model = initialize_detr_model(config, d_out) + elif config.model == 'fasterrcnn': + if is_featurizer: # TODO + raise NotImplementedError('Featurizer not implemented for detection yet') + else: + model = initialize_fasterrcnn_model(config, d_out) else: raise ValueError(f'Model: {config.model} not recognized.') @@ -142,6 +147,20 @@ def initialize_torchvision_model(name, d_out, **kwargs): setattr(model, last_layer_name, last_layer) return model + +def initialize_fasterrcnn_model(config, d_out): + + from models.detection.fasterrcnn import fasterrcnn_resnet50_fpn + + # load a model pre-trained pre-trained on COCO + model = fasterrcnn_resnet50_fpn(pretrained=config.model_kwargs["pretrained"],num_classes=d_out) + + return model + + + + + def initialize_detr_model(config, d_out): from models.detr.backbone import Backbone, Joiner diff --git a/examples/utils.py b/examples/utils.py index 1a20ab45..c554c38e 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -184,6 +184,8 @@ def save_pred(y_pred, path_prefix): # Dictionary elif isinstance(y_pred, dict): torch.save(y_pred, path_prefix + '.pth') + elif isinstance(y_pred, list): + torch.save(y_pred, path_prefix + '.pth') else: raise TypeError("Invalid type for save_pred") diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py index b1d8c021..84e592d3 100644 --- a/wilds/common/metrics/all_metrics.py +++ b/wilds/common/metrics/all_metrics.py @@ -181,13 +181,13 @@ def __init__(self, prediction_fn=None, iou_threshold=0.5,score_threshold=0.5, na def _compute_element_wise(self, y_pred ,y_true ): - batch_results = [] - for src_boxes, target_boxes, target_logits in zip( y_true, y_pred['pred_boxes'], y_pred['pred_logits']): - + for src_boxes, target in zip( y_true, y_pred): + target_boxes = target["boxes"] + target_scores = target["scores"] # Here should be prediction_fn ? - target_scores = F.softmax(target_logits, dim=1)[..., 0] + #target_scores = F.softmax(target_logits, dim=1)[..., 0] pred_boxes = target_boxes[target_scores > self.score_threshold] det_accuracy = self._accuracy(src_boxes["boxes"],pred_boxes,iou_threshold=self.iou_threshold) diff --git a/wilds/common/utils.py b/wilds/common/utils.py index e3b1440d..47446b06 100644 --- a/wilds/common/utils.py +++ b/wilds/common/utils.py @@ -81,6 +81,8 @@ def avg_over_groups(v, g, n_groups): group_avgs (Tensor): Vector of length num_groups group_counts (Tensor) """ + + assert v.device==g.device assert v.numel()==g.numel() group_count = get_counts(g, n_groups) diff --git a/wilds/datasets/gwhd_dataset.py b/wilds/datasets/gwhd_dataset.py index 151f9a0c..d21cbd5e 100644 --- a/wilds/datasets/gwhd_dataset.py +++ b/wilds/datasets/gwhd_dataset.py @@ -110,7 +110,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' torch.tensor([int(i) for i in box.split(" ")]) for box in boxes.split(";") ]), - "labels": torch.tensor([0]*len(list(boxes.split(";")))).long() + "labels": torch.tensor([1]*len(list(boxes.split(";")))).long() } if type(boxes) != float else { "boxes": torch.empty(0,4), # "labels": torch.empty(0,1,dtype=torch.long) @@ -121,6 +121,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' # The above boxes are (x_min,y_min,x_max,y_max) # Convert labels into (center_x, center_y, w, h) normalized, which is what DETR expects # TODO: If it's not standard, we can probably put this in a transform somewhere + """ for label in labels: boxes = label['boxes'] center_x = (boxes[:, 0] + boxes[:, 2]) / 2 / self._original_resolution[0] @@ -128,7 +129,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' width = (boxes[:, 2] - boxes[:, 0]) / self._original_resolution[0] height = (boxes[:, 3] - boxes[:, 1]) / self._original_resolution[1] label['boxes'] = torch.stack((center_x, center_y, width, height), dim=1) - + """ # num_boxes = [len(example['boxes']) for example in labels] # print(f'Max num_boxes is {max(num_boxes)}') From 2d97bcb81254e021287a10cef5904bdbdee47d1b Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 30 Mar 2021 15:20:06 +0200 Subject: [PATCH 115/244] remove vssettings --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 153471f1..16c552de 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ dist wilds.egg-info data logs -test_faster \ No newline at end of file +test_faster +.vscode \ No newline at end of file From bcc73a98997dad39b269b92f2819b708a41bf573 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 30 Mar 2021 16:28:52 +0200 Subject: [PATCH 116/244] remove normalization within fasterrcnn --- debug_faster_WILDS.ipynb | 0 examples/configs/datasets.py | 4 ++-- examples/models/detection/fasterrcnn.py | 11 ++++++----- 3 files changed, 8 insertions(+), 7 deletions(-) create mode 100644 debug_faster_WILDS.ipynb diff --git a/debug_faster_WILDS.ipynb b/debug_faster_WILDS.ipynb new file mode 100644 index 00000000..e69de29b diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 78bc9cb2..c1bcb5bb 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -298,10 +298,10 @@ 'optimizer': 'Adam', 'optimizer_kwargs': {}, 'scheduler': None, - 'batch_size': 4, + 'batch_size': 8, 'lr': 1e-5, 'weight_decay': 1e-4, - 'n_epochs': 50, + 'n_epochs': 1, 'loader_kwargs': { 'num_workers': 1, 'pin_memory': True, diff --git a/examples/models/detection/fasterrcnn.py b/examples/models/detection/fasterrcnn.py index a5aea7a7..8863ab26 100644 --- a/examples/models/detection/fasterrcnn.py +++ b/examples/models/detection/fasterrcnn.py @@ -26,10 +26,10 @@ from torchvision.models.detection.rpn import RPNHead, RegionProposalNetwork, concat_box_prediction_layers,permute_and_flatten from torchvision.models.detection.roi_heads import RoIHeads -from torchvision.models.detection.transform import GeneralizedRCNNTransform from torchvision.models.detection import _utils as det_utils from torch.nn import functional as F +from torchvision.models.detection.transform import GeneralizedRCNNTransform model_urls = { @@ -41,6 +41,7 @@ 'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth' } + def batch_concat_box_prediction_layers(box_cls, box_regression): # type: (List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor] box_cls_flattened = [] @@ -420,10 +421,9 @@ def __init__(self, backbone, num_classes=None, bbox_reg_weights, box_score_thresh, box_nms_thresh, box_detections_per_img) - if image_mean is None: - image_mean = [0.485, 0.456, 0.406] - if image_std is None: - image_std = [0.229, 0.224, 0.225] + + image_mean = [0., 0., 0.] # small trick because images are already normalized + image_std = [1., 1., 1.] transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std) super(FastWILDS, self).__init__(backbone, rpn, roi_heads, transform) @@ -450,6 +450,7 @@ def forward(self, images, targets=None): assert len(val) == 2 original_image_sizes.append((val[0], val[1])) + images, targets = self.transform(images, targets) # Check for degenerate boxes From 6f9f184eb7acbd39b794141d6df75cd42ceef8e3 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 30 Mar 2021 16:29:09 +0200 Subject: [PATCH 117/244] remove normalization within fasterrcnn --- debug_faster_WILDS.ipynb | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 debug_faster_WILDS.ipynb diff --git a/debug_faster_WILDS.ipynb b/debug_faster_WILDS.ipynb deleted file mode 100644 index e69de29b..00000000 From a31432496f666f37d2766a5a14b7c0d6efcdec46 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 30 Mar 2021 21:24:14 +0200 Subject: [PATCH 118/244] change roi head --- examples/configs/datasets.py | 2 +- examples/models/detection/fasterrcnn.py | 28 ++++++++++++++++--------- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index c1bcb5bb..f890c360 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -301,7 +301,7 @@ 'batch_size': 8, 'lr': 1e-5, 'weight_decay': 1e-4, - 'n_epochs': 1, + 'n_epochs': 10, 'loader_kwargs': { 'num_workers': 1, 'pin_memory': True, diff --git a/examples/models/detection/fasterrcnn.py b/examples/models/detection/fasterrcnn.py index 8863ab26..8167c1f4 100644 --- a/examples/models/detection/fasterrcnn.py +++ b/examples/models/detection/fasterrcnn.py @@ -278,23 +278,31 @@ def forward(self, # here batch is maintained - proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets) - + if self.training: + proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets) + else: + labels = None + regression_targets = None + matched_idxs = None - box_features = self.box_roi_pool(features, proposals, image_shapes) # batch is maintained + box_features = self.box_roi_pool(features, proposals, image_shapes) box_features = self.box_head(box_features) class_logits, box_regression = self.box_predictor(box_features) result = torch.jit.annotate(List[Dict[str, torch.Tensor]], []) losses = {} - assert labels is not None and regression_targets is not None - loss_classifier, loss_box_reg = fastrcnn_loss( - class_logits, box_regression, labels, regression_targets) - losses = { - "loss_classifier": loss_classifier, - "loss_box_reg": loss_box_reg - } + + if self.training: + assert labels is not None and regression_targets is not None + loss_classifier, loss_box_reg = fastrcnn_loss( + class_logits, box_regression, labels, regression_targets) + losses = { + "loss_classifier": loss_classifier, + "loss_box_reg": loss_box_reg + } + + boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes) num_images = len(boxes) for i in range(num_images): From 6eda4049fc6fedd44db40481a50b2852a819b0e4 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Tue, 30 Mar 2021 15:13:36 -0700 Subject: [PATCH 119/244] add needs_y parameter --- examples/algorithms/single_model_algorithm.py | 10 ++++++++-- examples/configs/model.py | 3 ++- examples/models/initializer.py | 8 +++++--- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/examples/algorithms/single_model_algorithm.py b/examples/algorithms/single_model_algorithm.py index 94767742..fbb2dd80 100644 --- a/examples/algorithms/single_model_algorithm.py +++ b/examples/algorithms/single_model_algorithm.py @@ -51,7 +51,13 @@ def process_batch(self, batch): x = move_to(x, self.device) y_true = move_to(y_true, self.device) g = move_to(self.grouper.metadata_to_group(metadata), self.device) - outputs = self.model(x, y_true) + if self.model.needs_y: + if self.training: + outputs = self.model(x, y_true) + else: + outputs = self.model(x, None) + else: + outputs = self.model(x) results = { 'g': g, @@ -80,7 +86,7 @@ def evaluate(self, batch): """ assert not self.is_training results = self.process_batch(batch) - results['objective'] = self.objective(results).item() + results['objective'] = self.objective(results).item() self.update_log(results) return self.sanitize_dict(results) diff --git a/examples/configs/model.py b/examples/configs/model.py index 31d755ef..ca6962f3 100644 --- a/examples/configs/model.py +++ b/examples/configs/model.py @@ -71,6 +71,7 @@ 'model_kwargs': { # Backbone. Always uses sine position embedding. 'pretrained': True, - } + }, + 'needs_y': True } } diff --git a/examples/models/initializer.py b/examples/models/initializer.py index ab9942e6..f2b79f85 100644 --- a/examples/models/initializer.py +++ b/examples/models/initializer.py @@ -91,6 +91,11 @@ def initialize_model(config, d_out, is_featurizer=False): else: raise ValueError(f'Model: {config.model} not recognized.') + if config.model_kwargs.get('needs_y'): + model.needs_y = True + else: + model.needs_y = False + return model @@ -158,9 +163,6 @@ def initialize_fasterrcnn_model(config, d_out): return model - - - def initialize_detr_model(config, d_out): from models.detr.backbone import Backbone, Joiner From 965375b797a08ef3fbd3a8367bbe8b0913574bf6 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Tue, 30 Mar 2021 15:26:31 -0700 Subject: [PATCH 120/244] remove needs_y from config and move into initializer instead --- examples/configs/model.py | 3 +-- examples/models/initializer.py | 11 +++++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/configs/model.py b/examples/configs/model.py index ca6962f3..6e1a870f 100644 --- a/examples/configs/model.py +++ b/examples/configs/model.py @@ -71,7 +71,6 @@ 'model_kwargs': { # Backbone. Always uses sine position embedding. 'pretrained': True, - }, - 'needs_y': True + } } } diff --git a/examples/models/initializer.py b/examples/models/initializer.py index f2b79f85..c8ee9245 100644 --- a/examples/models/initializer.py +++ b/examples/models/initializer.py @@ -87,13 +87,16 @@ def initialize_model(config, d_out, is_featurizer=False): raise NotImplementedError('Featurizer not implemented for detection yet') else: model = initialize_fasterrcnn_model(config, d_out) - + model.needs_y = True else: raise ValueError(f'Model: {config.model} not recognized.') - if config.model_kwargs.get('needs_y'): - model.needs_y = True - else: + # The `needs_y` attribute specifies whether the model's forward function + # needs to take in both (x, y). + # If False, Algorithm.process_batch will call model(x). + # If True, Algorithm.process_batch() will call model(x, y) during training, + # and model(x, None) during eval. + if not hasattr(model, 'needs_y'): model.needs_y = False return model From b9ef3a2c5d73199e6c1671560d30509e9269719f Mon Sep 17 00:00:00 2001 From: aikanor Date: Tue, 30 Mar 2021 17:48:15 -0700 Subject: [PATCH 121/244] preprocessing fixes --- .../encode-tfbs/prep_metadata_labels.py | 42 +++++++++---------- wilds/datasets/encodetfbs_dataset.py | 34 +++++++++++++++ 2 files changed, 55 insertions(+), 21 deletions(-) diff --git a/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py b/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py index ca8c142f..006c07bb 100644 --- a/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py +++ b/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py @@ -10,19 +10,19 @@ _data_dir = '../../examples/data/encode-tfbs_v1.0/' -def write_label_bigwigs(celltypes): +def write_label_bigwigs( + celltypes, + train_suffix='train.labels.tsv.gz', + val_suffix='val.labels.tsv.gz' +): itime = time.time() tf_name = 'MAX' - _train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] - _val_chroms = ['chr2', 'chr9', 'chr11'] - _test_chroms = ['chr1', 'chr8', 'chr21'] - _all_chroms = _train_chroms + _val_chroms + _test_chroms # Read in metadata dataframe from training+validation data - train_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.train.labels.tsv.gz'.format(tf_name)), sep='\t') - val_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.val.labels.tsv.gz'.format(tf_name)), sep='\t') - training_df = train_regions_labeled# [np.isin(train_regions_labeled['chr'], _train_chroms)] - val_df = val_regions_labeled# [np.isin(val_regions_labeled['chr'], _test_chroms)] + train_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.{}'.format(tf_name, train_suffix)), sep='\t') + val_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.{}'.format(tf_name, val_suffix)), sep='\t') + training_df = train_regions_labeled + val_df = val_regions_labeled all_df = pd.concat([training_df, val_df]) # Get the y values, and remove labels by default. @@ -70,15 +70,19 @@ def write_label_bigwigs(celltypes): bw.close() -def write_metadata_products(celltypes, stride=6400, posamb_only=False): +def write_metadata_products( + celltypes, bed_df_filename='metadata_df.bed', y_arr_filename='metadata_y.npy', + stride=6400, posamb_only=False +): itime = time.time() tf_name = 'MAX' celltype_mdta = [] celltype_labels = [] - mdf_posamb = pd.read_csv( - _data_dir + 'labels/{}/{}_posamb.sorted.bed'.format(tf_name, tf_name), - sep='\t', header=None, index_col=None, names=['chr', 'start', 'stop', 'y', 'celltype'] - ) + if posamb_only: + mdf_posamb = pd.read_csv( + _data_dir + 'labels/{}/{}_posamb.sorted.bed'.format(tf_name, tf_name), + sep='\t', header=None, index_col=None, names=['chr', 'start', 'stop', 'y', 'celltype'] + ) # Retrieve only the windows containing positively/ambiguously labeled bins (if posamb_only==True), or all windows (if posamb_only==False). for ct in celltypes: ct_labels_bw_path = _data_dir + "labels/{}/{}_{}.bigwig".format(tf_name, tf_name, ct) @@ -118,17 +122,13 @@ def write_metadata_products(celltypes, stride=6400, posamb_only=False): all_metadata_df = pd.concat(celltype_mdta) all_metadata_df.to_csv( - _data_dir + 'labels/{}/metadata_df.bed'.format(tf_name), + _data_dir + 'labels/{}/{}'.format(tf_name, bed_df_filename), sep='\t', header=False, index=False ) - np.save(_data_dir + 'labels/{}/metadata_y.npy'.format(tf_name), np.vstack(celltype_labels)) + np.save(_data_dir + 'labels/{}/{}'.format(tf_name, y_arr_filename), np.vstack(celltype_labels)) if __name__ == '__main__': - _train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562'] - _val_celltype = ['A549'] - _test_celltype = ['GM12878'] - _all_celltypes = _train_celltypes + _val_celltype + _test_celltype + _all_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562', 'A549', 'GM12878'] write_label_bigwigs(_all_celltypes) write_metadata_products(_all_celltypes) - \ No newline at end of file diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index e2634bd6..8c785c4d 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -118,6 +118,40 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' 'val': 'Validation (OOD)', 'test': 'Test', } + elif self._split_scheme == 'challenge': + ch_train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'K562', 'A549', 'GM12878'] + ch_val_celltype = ['HepG2'] + ch_test_celltype = ['liver'] + splits = { + 'train': { + 'chroms': train_chroms, + 'celltypes': ch_train_celltypes + }, + 'id_val': { + 'chroms': val_chroms, + 'celltypes': ch_train_celltypes + }, + 'val': { + 'chroms': val_chroms, + 'celltypes': ch_val_celltype + }, + 'test': { + 'chroms': test_chroms, + 'celltypes': ch_test_celltype + }, + } + self._split_dict = { + 'train': 0, + 'val': 1, + 'test': 2, + 'id_val': 3, + } + self._split_names = { + 'train': 'Train', + 'val': 'Validation (OOD)', + 'test': 'Test', + 'id_val': 'Validation (ID)', + } else: raise ValueError(f'Split scheme {self._split_scheme} not recognized') From 987c557ac03a5d9ceea983015ccec43fa407a245 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Wed, 31 Mar 2021 10:12:13 -0700 Subject: [PATCH 122/244] spacing --- .../encode-tfbs/prep_accessibility.py | 6 ++-- .../encode-tfbs/prep_metadata_labels.py | 22 +++++++------ .../encode-tfbs/prep_sequence.py | 32 ++++++++++--------- 3 files changed, 33 insertions(+), 27 deletions(-) diff --git a/dataset_preprocessing/encode-tfbs/prep_accessibility.py b/dataset_preprocessing/encode-tfbs/prep_accessibility.py index 141981c0..514b025e 100644 --- a/dataset_preprocessing/encode-tfbs/prep_accessibility.py +++ b/dataset_preprocessing/encode-tfbs/prep_accessibility.py @@ -3,7 +3,9 @@ import pyBigWig # Human chromosome names -chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] +chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', + 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', + 'chr20', 'chr21', 'chr22', 'chrX'] def generate_accessibility_archives(input_dir='dnase_bigwigs', output_dir='codalab_archive'): dnases = {} @@ -38,4 +40,4 @@ def generate_accessibility_archives(input_dir='dnase_bigwigs', output_dir='codal generate_accessibility_archives( input_dir=args.input_dir, - output_dir=args.output_dir) \ No newline at end of file + output_dir=args.output_dir) diff --git a/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py b/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py index 006c07bb..2c8c0fb0 100644 --- a/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py +++ b/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py @@ -4,15 +4,17 @@ import pyBigWig # Human chromosome names -chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] +chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', + 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', + 'chr20', 'chr21', 'chr22', 'chrX'] chrom_sizes = {'chr1': 249250621, 'chr10': 135534747, 'chr11': 135006516, 'chr12': 133851895, 'chr13': 115169878, 'chr14': 107349540, 'chr15': 102531392, 'chr16': 90354753, 'chr17': 81195210, 'chr18': 78077248, 'chr19': 59128983, 'chr2': 243199373, 'chr20': 63025520, 'chr21': 48129895, 'chr22': 51304566, 'chr3': 198022430, 'chr4': 191154276, 'chr5': 180915260, 'chr6': 171115067, 'chr7': 159138663, 'chr8': 146364022, 'chr9': 141213431, 'chrX': 155270560} _data_dir = '../../examples/data/encode-tfbs_v1.0/' def write_label_bigwigs( - celltypes, - train_suffix='train.labels.tsv.gz', + celltypes, + train_suffix='train.labels.tsv.gz', val_suffix='val.labels.tsv.gz' ): itime = time.time() @@ -42,7 +44,7 @@ def write_label_bigwigs( _unsorted_dir = _data_dir + 'labels/{}/{}_posamb.bed'.format( tf_name, tf_name) _sorted_dir = _unsorted_dir.replace( - '{}_posamb'.format(tf_name), + '{}_posamb'.format(tf_name), '{}_posamb.sorted'.format(tf_name) ) _metadata_df.to_csv( @@ -53,10 +55,10 @@ def write_label_bigwigs( os.system('sort -k1,1 -k2,2n {} > {}'.format(_unsorted_dir, _sorted_dir)) mdf_posamb = pd.read_csv( - _sorted_dir, + _sorted_dir, sep='\t', header=None, index_col=None, names=['chr', 'start', 'stop', 'y', 'celltype'] ) - + # Write the binned labels to bigwig files - genome-wide labels chromsizes_list = [(k, v) for k, v in chrom_sizes.items()] for ct in celltypes: @@ -71,7 +73,7 @@ def write_label_bigwigs( def write_metadata_products( - celltypes, bed_df_filename='metadata_df.bed', y_arr_filename='metadata_y.npy', + celltypes, bed_df_filename='metadata_df.bed', y_arr_filename='metadata_y.npy', stride=6400, posamb_only=False ): itime = time.time() @@ -80,7 +82,7 @@ def write_metadata_products( celltype_labels = [] if posamb_only: mdf_posamb = pd.read_csv( - _data_dir + 'labels/{}/{}_posamb.sorted.bed'.format(tf_name, tf_name), + _data_dir + 'labels/{}/{}_posamb.sorted.bed'.format(tf_name, tf_name), sep='\t', header=None, index_col=None, names=['chr', 'start', 'stop', 'y', 'celltype'] ) # Retrieve only the windows containing positively/ambiguously labeled bins (if posamb_only==True), or all windows (if posamb_only==False). @@ -119,10 +121,10 @@ def write_metadata_products( print(ct, time.time() - itime) bw.close() print(time.time() - itime) - + all_metadata_df = pd.concat(celltype_mdta) all_metadata_df.to_csv( - _data_dir + 'labels/{}/{}'.format(tf_name, bed_df_filename), + _data_dir + 'labels/{}/{}'.format(tf_name, bed_df_filename), sep='\t', header=False, index=False ) np.save(_data_dir + 'labels/{}/{}'.format(tf_name, y_arr_filename), np.vstack(celltype_labels)) diff --git a/dataset_preprocessing/encode-tfbs/prep_sequence.py b/dataset_preprocessing/encode-tfbs/prep_sequence.py index 3ead9a27..b80be0da 100644 --- a/dataset_preprocessing/encode-tfbs/prep_sequence.py +++ b/dataset_preprocessing/encode-tfbs/prep_sequence.py @@ -6,14 +6,16 @@ # Sequence preprocessing. Code adapted from Jacob Schreiber. # Human chromosome names -chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] +chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', + 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', + 'chr20', 'chr21', 'chr22', 'chrX'] def one_hot_encode(sequence, ignore='N', alphabet=None, dtype='int8', verbose=False, **kwargs): """ - Converts a string or list of characters into a one-hot encoding. - This function will take in either a string or a list and convert it into a one-hot encoding. If the input is a string, each character is assumed to be a different symbol, e.g. 'ACGT' is assumed to be a sequence of four characters. If the input is a list, the elements can be any size. + Converts a string or list of characters into a one-hot encoding. + This function will take in either a string or a list and convert it into a one-hot encoding. If the input is a string, each character is assumed to be a different symbol, e.g. 'ACGT' is assumed to be a sequence of four characters. If the input is a list, the elements can be any size. Although this function will be used here primarily to convert nucleotide sequences into one-hot encoding with an alphabet of size 4, in principle this function can be used for any types of sequences. - + Parameters ---------- sequence : str or list @@ -28,29 +30,29 @@ def one_hot_encode(sequence, ignore='N', alphabet=None, dtype='int8', verbose=Fa Whether to display a progress bar. If a string is passed in, use as the name of the progressbar. Default is False. kwargs : arguments Arguments to be passed into tqdm. Default is None. - + Returns ------- ohe : numpy.ndarray A binary matrix of shape (alphabet_size, sequence_length) where alphabet_size is the number of unique elements in the sequence and sequence_length is the length of the input sequence. """ - + name = None if verbose in (True, False) else verbose d = verbose is False - + if isinstance(sequence, str): sequence = list(sequence) - + alphabet = alphabet or np.unique(sequence) alphabet = [char for char in alphabet if char != ignore] alphabet_lookup = {char: i for i, char in enumerate(alphabet)} - + ohe = np.zeros((len(sequence), len(alphabet)), dtype=dtype) for i, char in tqdm(enumerate(sequence), disable=d, desc=name, **kwargs): if char != ignore: idx = alphabet_lookup[char] ohe[i, idx] = 1 - + return ohe @@ -58,7 +60,7 @@ def read_fasta(filename, include_chroms=None, exclude_chroms=None, ignore='N', a """ Read in a FASTA file and output a dictionary of sequences. This function will take in the path to a FASTA-formatted file and output a string containing the sequence for each chromosome. Optionally, the user can specify a set of chromosomes to include or exclude from the returned dictionary. - + Parameters ---------- filename : str @@ -73,17 +75,17 @@ def read_fasta(filename, include_chroms=None, exclude_chroms=None, ignore='N', a A pre-defined alphabet. If None is passed in, the alphabet will be determined from the sequence, but this may be time consuming for large sequences. Must include the ignore character. Default is ['A', 'C', 'G', 'T', 'N']. verbose : bool or str, optional Whether to display a progress bar. If a string is passed in, use as the name of the progressbar. Default is False. - + Returns ------- chroms : dict A dictionary of strings where the keys are the names of the chromosomes (exact strings from the header lines in the FASTA file) and the values are the strings encoded there. """ - + sequences = {} name, sequence = None, None skip_chrom = False - + with open(filename, "r") as infile: for line in tqdm(infile, disable=not verbose): if line.startswith(">"): @@ -126,4 +128,4 @@ def generate_sequence_archive(seq_path='sequence/hg19.genome.fa', output_dir): generate_sequence_archive( seq_path=args.seq_path, - output_dir=args.output_dir) \ No newline at end of file + output_dir=args.output_dir) From ad2c60535f22599961b5b799211c9a854b4410cb Mon Sep 17 00:00:00 2001 From: aikanor Date: Wed, 31 Mar 2021 13:52:57 -0700 Subject: [PATCH 123/244] featurizer code --- .../encode-tfbs/prep_metadata_labels.py | 4 +- examples/configs/datasets.py | 2 +- examples/models/CNN_genome.py | 20 +++-- examples/models/initializer.py | 4 +- wilds/datasets/encodetfbs_dataset.py | 73 ++++++++++++++++++- 5 files changed, 89 insertions(+), 14 deletions(-) diff --git a/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py b/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py index 006c07bb..138a66dc 100644 --- a/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py +++ b/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py @@ -25,7 +25,7 @@ def write_label_bigwigs( val_df = val_regions_labeled all_df = pd.concat([training_df, val_df]) - # Get the y values, and remove labels by default. + # Get the y values, and remove negative labels by default. pd_list = [] for ct in celltypes: tc_chr = all_df[['chr', 'start', 'stop', ct]] @@ -50,8 +50,8 @@ def write_label_bigwigs( ) print(time.time() - itime) + # Sort bigwigs (as bed files) in order to convert to bigwig. os.system('sort -k1,1 -k2,2n {} > {}'.format(_unsorted_dir, _sorted_dir)) - mdf_posamb = pd.read_csv( _sorted_dir, sep='\t', header=None, index_col=None, names=['chr', 'start', 'stop', 'y', 'celltype'] diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 91ddca27..7cc44ab7 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -120,7 +120,7 @@ 'scheduler': None, 'batch_size': 128, 'lr': 1e-4, - 'weight_decay': 1e-4, + 'weight_decay': 1e-2, 'n_epochs': 1, 'n_groups_per_batch': 2, 'algo_log_metric': 'multitask_binary_accuracy', diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py index c767f61d..484632a0 100644 --- a/examples/models/CNN_genome.py +++ b/examples/models/CNN_genome.py @@ -25,10 +25,11 @@ def double_conv(in_channels, out_channels): class UNet(nn.Module): - # TODO: This is currently hard-coded to not use out_features - def __init__(self, out_features=16, n_channels_in=5): + def __init__(self, num_tasks=16, n_channels_in=5): super().__init__() - + + self.d_out = num_tasks + self.dconv_down1 = double_conv(n_channels_in, 15) self.dconv_down2 = double_conv(15, 22) self.dconv_down3 = double_conv(22, 33) @@ -58,7 +59,7 @@ def __init__(self, out_features=16, n_channels_in=5): def forward(self, x): # input_size = 12800 - # input_channels = 6 + # input_channels = 5 conv1 = self.dconv_down1(x) # Out: (input_size) x 15 x = self.maxpool(conv1) # (input_size / 2) x 15 @@ -100,7 +101,10 @@ def forward(self, x): x = self.dconv_up1(x) # (input_size) x 15 - # middle 128 bits - out = self.conv_last(x)[:, :, 64:192] - - return torch.squeeze(out) + x = self.conv_last(x) + + if self.d_out is None: + return x.shape[-1] # Default: 253 values + else: # middle 128 values + out = x[:, :, 64:192] + return torch.squeeze(out) diff --git a/examples/models/initializer.py b/examples/models/initializer.py index 4f6fef7e..4effde23 100644 --- a/examples/models/initializer.py +++ b/examples/models/initializer.py @@ -75,7 +75,9 @@ def initialize_model(config, d_out, is_featurizer=False): model = nn.Linear(out_features=d_out, **config.model_kwargs) elif config.model == 'unet-seq': if is_featurizer: - raise NotImplementedError("Featurizer not supported for UNet") + featurizer = UNet(out_features=None, **config.model_kwargs) + classifier = nn.Linear(featurizer.d_out, d_out) + model = (featurizer, classifier) else: model = UNet(out_features=d_out, **config.model_kwargs) else: diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 8c785c4d..f6fe29cb 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -17,13 +17,15 @@ class EncodeTFBSDataset(WILDSDataset): 12800-base-pair regions of sequence with a quantified chromatin accessibility readout. Label (y): - y is a 128-bit vector, with each element y_i indicating the binding status of a 200bp window. It is 1 if this 200bp region is bound by the transcription factor, and 0 otherwise. If the window x starts at coordinate sc, y_i is the label of the window starting at coordinate (sc+3200)+(50*i). + y is a 128-bit vector, with each element y_i indicating the binding status of a 200bp window. It is 1 if this 200bp region is bound by the transcription factor, and 0 otherwise, for i = 0,1,...,127. + + Suppose the input window x starts at coordinate sc, extending until coordinate (sc+12800). Then y_i is the label of the window starting at coordinate (sc+3200)+(50*i). Metadata: Each sequence is annotated with the celltype of origin (a string) and the chromosome of origin (a string). Website: - https://www.synapse.org/#!Synapse:syn6131484 + https://www.synapse.org/#!Synapse:syn6131484 . This is the website for the challenge; the data can be downloaded from here into the meta """ _dataset_name = 'encode-tfbs' @@ -118,6 +120,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' 'val': 'Validation (OOD)', 'test': 'Test', } + # Add challenge splits, assuming 'liver' celltype is in the data. elif self._split_scheme == 'challenge': ch_train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'K562', 'A549', 'GM12878'] ch_val_celltype = ['HepG2'] @@ -152,6 +155,72 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' 'test': 'Test', 'id_val': 'Validation (ID)', } + elif self._split_scheme == 'challenge_in-dist': + ch_train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'K562', 'A549', 'GM12878'] + ch_val_celltype = ['HepG2'] + ch_test_celltype = ['liver'] + splits = { + 'train': { + 'chroms': train_chroms, + 'celltypes': ch_test_celltype, + }, + 'val': { + 'chroms': val_chroms, + 'celltypes': ch_test_celltype + }, + 'test': { + 'chroms': test_chroms, + 'celltypes': ch_test_celltype + }, + } + self._split_dict = { + 'train': 0, + 'val': 1, + 'test': 2, + } + self._split_names = { + 'train': 'Train', + 'val': 'Validation (OOD)', + 'test': 'Test', + } + # Add new split scheme specifying custom test and val celltypes in the format test..val.. + elif '.' in self._split_scheme: + all_celltypes = train_celltypes + val_celltype + test_celltype + in_val_ct = self._split_scheme.split('.')[1] + in_test_ct = self._split_scheme.split('.')[3] + train_celltypes = [ct for ct in all_celltypes if ((ct != in_val_ct) and (ct != in_test_ct))] + val_celltype = [in_val_ct] + test_celltype = [in_test_ct] + splits = { + 'train': { + 'chroms': train_chroms, + 'celltypes': train_celltypes + }, + 'id_val': { + 'chroms': val_chroms, + 'celltypes': train_celltypes + }, + 'val': { + 'chroms': val_chroms, + 'celltypes': val_celltype + }, + 'test': { + 'chroms': test_chroms, + 'celltypes': test_celltype + }, + } + self._split_dict = { + 'train': 0, + 'val': 1, + 'test': 2, + 'id_val': 3, + } + self._split_names = { + 'train': 'Train', + 'val': 'Validation (OOD)', + 'test': 'Test', + 'id_val': 'Validation (ID)', + } else: raise ValueError(f'Split scheme {self._split_scheme} not recognized') From 81d4e6a6a679d60e68d2f2c6e183a22ec357356b Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 1 Apr 2021 08:28:24 +0200 Subject: [PATCH 124/244] remove use of targets during eval + new splits --- .gitignore | 4 +- examples/algorithms/single_model_algorithm.py | 10 +++- examples/configs/datasets.py | 2 +- examples/models/detection/fasterrcnn.py | 58 +++++++++++-------- examples/models/initializer.py | 2 +- examples/train.py | 2 + wilds/datasets/gwhd_dataset.py | 30 ++++------ 7 files changed, 61 insertions(+), 47 deletions(-) diff --git a/.gitignore b/.gitignore index 16c552de..388819f7 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,6 @@ wilds.egg-info data logs test_faster -.vscode \ No newline at end of file +paper +.vscode +*sh \ No newline at end of file diff --git a/examples/algorithms/single_model_algorithm.py b/examples/algorithms/single_model_algorithm.py index 94767742..753245b4 100644 --- a/examples/algorithms/single_model_algorithm.py +++ b/examples/algorithms/single_model_algorithm.py @@ -51,8 +51,16 @@ def process_batch(self, batch): x = move_to(x, self.device) y_true = move_to(y_true, self.device) g = move_to(self.grouper.metadata_to_group(metadata), self.device) - outputs = self.model(x, y_true) + + if self.model.needs_y: + if self.training: + outputs = self.model(x, y_true) + else: + outputs = self.model(x, None) + else: + outputs = self.model(x) + results = { 'g': g, 'y_true': y_true, diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index f890c360..92bd1ab3 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -298,7 +298,7 @@ 'optimizer': 'Adam', 'optimizer_kwargs': {}, 'scheduler': None, - 'batch_size': 8, + 'batch_size': 4, 'lr': 1e-5, 'weight_decay': 1e-4, 'n_epochs': 10, diff --git a/examples/models/detection/fasterrcnn.py b/examples/models/detection/fasterrcnn.py index 8167c1f4..d577090b 100644 --- a/examples/models/detection/fasterrcnn.py +++ b/examples/models/detection/fasterrcnn.py @@ -186,16 +186,18 @@ def forward(self, boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level) losses = {} - assert targets is not None - labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets) - regression_targets = self.box_coder.encode(matched_gt_boxes, anchors) - loss_objectness, loss_rpn_box_reg = self.compute_loss( - raw_objectness, raw_pred_bbox_deltas, labels, regression_targets) - - losses = { - "loss_objectness": loss_objectness, - "loss_rpn_box_reg": loss_rpn_box_reg, - } + + if self.training: + assert targets is not None + labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets) + regression_targets = self.box_coder.encode(matched_gt_boxes, anchors) + loss_objectness, loss_rpn_box_reg = self.compute_loss( + raw_objectness, raw_pred_bbox_deltas, labels, regression_targets) + + losses = { + "loss_objectness": loss_objectness, + "loss_rpn_box_reg": loss_rpn_box_reg, + } return boxes, losses def fastrcnn_loss(class_logits, box_regression, labels, regression_targets): @@ -437,20 +439,22 @@ def __init__(self, backbone, num_classes=None, super(FastWILDS, self).__init__(backbone, rpn, roi_heads, transform) # Set your own forward pass def forward(self, images, targets=None): + - if targets is None: - raise ValueError("In training mode, targets should be passed") - assert targets is not None - for target in targets: - boxes = target["boxes"] - if isinstance(boxes, torch.Tensor): - if len(boxes.shape) != 2 or boxes.shape[-1] != 4: - raise ValueError("Expected target boxes to be a tensor" - "of shape [N, 4], got {:}.".format( - boxes.shape)) - else: - raise ValueError("Expected target boxes to be of type " - "Tensor, got {:}.".format(type(boxes))) + if self.training: + if targets is None: + raise ValueError("In training mode, targets should be passed") + assert targets is not None + for target in targets: + boxes = target["boxes"] + if isinstance(boxes, torch.Tensor): + if len(boxes.shape) != 2 or boxes.shape[-1] != 4: + raise ValueError("Expected target boxes to be a tensor" + "of shape [N, 4], got {:}.".format( + boxes.shape)) + else: + raise ValueError("Expected target boxes to be of type " + "Tensor, got {:}.".format(type(boxes))) original_image_sizes: List[Tuple[int, int]] = [] for img in images: @@ -493,6 +497,8 @@ def forward(self, images, targets=None): det["losses"][k] = v[idx] for k,v in detector_losses.items(): det["losses"][k] = v[idx] + + return detections @@ -509,8 +515,10 @@ def forward(self, outputs, targets): # loss values are loss_classifier loss_box_reg loss_objectness": loss_objectness, loss_rpn_box_reg - - elementwise_loss = torch.stack([sum(v for v in item["losses"].values()) for item in outputs]) + try: + elementwise_loss = torch.stack([sum(v for v in item["losses"].values()) for item in outputs]) + except: + elementwise_loss = torch.ones(len(outputs)).to(self.device) diff --git a/examples/models/initializer.py b/examples/models/initializer.py index ab9942e6..92db0278 100644 --- a/examples/models/initializer.py +++ b/examples/models/initializer.py @@ -87,7 +87,7 @@ def initialize_model(config, d_out, is_featurizer=False): raise NotImplementedError('Featurizer not implemented for detection yet') else: model = initialize_fasterrcnn_model(config, d_out) - + model.needs_y = True else: raise ValueError(f'Model: {config.model} not recognized.') diff --git a/examples/train.py b/examples/train.py index 22028ded..963049e3 100644 --- a/examples/train.py +++ b/examples/train.py @@ -62,6 +62,8 @@ def run_epoch(algorithm, dataset, general_logger, epoch, config, train): epoch_y_pred = collate_list(epoch_y_pred) epoch_y_true = collate_list(epoch_y_true) epoch_metadata = collate_list(epoch_metadata) + + results, results_str = dataset['dataset'].eval( epoch_y_pred, epoch_y_true, diff --git a/wilds/datasets/gwhd_dataset.py b/wilds/datasets/gwhd_dataset.py index d21cbd5e..b2b98e7e 100644 --- a/wilds/datasets/gwhd_dataset.py +++ b/wilds/datasets/gwhd_dataset.py @@ -27,9 +27,7 @@ class GWHDDataset(WILDSDataset): 'official' for WILDS related tasks. To reproduce the baseline, several splits are needed: - to train a model on train domains and test against a all test split: 'train_in-dist' - - to train a model on a portion of a specific val or test domain and test it against the remaining portion: - "{domain}_in-dist" where domain is the id of a domain (usask_1, uq_1, utokyo_1, utokyo_2, nau_1) - no validation datasets are accessible for the baseline splits + - "benchmark_biased" ; "benchmark_in-dist" Input (x): 1024x1024 RGB images of wheat field canopy between flowering and ripening. Output (y): @@ -78,23 +76,19 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' # Get filenames if split_scheme =="official": - train_data_df = pd.read_csv(self.root / f'{split_scheme}_train.csv') - val_data_df = pd.read_csv(self.root / f'{split_scheme}_val.csv') - test_data_df = pd.read_csv(self.root / f'{split_scheme}_test.csv') + train_data_df = pd.read_csv(self.root / f'official_train.csv') + val_data_df = pd.read_csv(self.root / f'official_val.csv') + test_data_df = pd.read_csv(self.root / f'official_test.csv') - elif split_scheme == "train_in-dist": + elif split_scheme == "benchmark_biased": train_data_df = pd.read_csv(self.root / f'official_train.csv') - test_data_df = pd.read_csv(self.root / f'{split_scheme}_test.csv') - val_data_df = pd.DataFrame(columns=["image","labels","group"]) - elif split_scheme in [f"{domain}_in-dist" for domain in ["nau_1", "utokyo_1", "utokyo_2", "usask_1" , "uq_1"]]: - train_data_df = pd.read_csv(self.root / f'{split_scheme}_train.csv') - test_data_df = pd.read_csv(self.root / f'{split_scheme}_test.csv') - val_data_df = pd.DataFrame(columns=["image","labels","group"]) - - elif split_scheme == "in-dist": - train_data_df = pd.read_csv(self.root / f'{split_scheme}_train.csv') - test_data_df = pd.read_csv(self.root / f'{split_scheme}_test.csv') - val_data_df = pd.DataFrame(columns=["image","labels","group"]) + val_data_df = pd.read_csv(self.root / f'official_val.csv') + test_data_df = pd.read_csv(self.root / f'in-dist_test.csv') + + elif split_scheme == "benchmark_in-dist": + train_data_df = pd.read_csv(self.root / f'in-dist_train.csv') + val_data_df = pd.read_csv(self.root / f'official_val.csv') + test_data_df = pd.read_csv(self.root / f'in-dist_test.csv') self._image_array = [] From 94536428301c994b19e8c435d4c34e1c687a5af5 Mon Sep 17 00:00:00 2001 From: aikanor Date: Thu, 1 Apr 2021 17:15:42 -0700 Subject: [PATCH 125/244] add DNase normalization --- dataset_preprocessing/encode-tfbs/README.md | 8 +- .../encode-tfbs/prep_accessibility.py | 135 +++++++++++++++++- .../encode-tfbs/prep_metadata_labels.py | 2 +- .../encode-tfbs/write_label_bigwig.py | 93 ------------ examples/models/CNN_genome.py | 14 +- examples/models/initializer.py | 4 +- wilds/datasets/encodetfbs_dataset.py | 82 ++++++++++- 7 files changed, 221 insertions(+), 117 deletions(-) delete mode 100644 dataset_preprocessing/encode-tfbs/write_label_bigwig.py diff --git a/dataset_preprocessing/encode-tfbs/README.md b/dataset_preprocessing/encode-tfbs/README.md index 7ecf1135..b4f806ec 100644 --- a/dataset_preprocessing/encode-tfbs/README.md +++ b/dataset_preprocessing/encode-tfbs/README.md @@ -12,8 +12,8 @@ 3. Download the DNase accessibility data. This consists of whole-genome DNase files in bigwig format from https://guanfiles.dcmb.med.umich.edu/Leopard/dnase_bigwig/. 4. Download the labels from the challenge into a label directory created for this purpose: - - The training labels from https://www.synapse.org/#!Synapse:syn7413983 for the relevant transcription factor (e.g. https://www.synapse.org/#!Synapse:syn7415202 for the TF MAX). - - The validation labels from https://www.synapse.org/#!Synapse:syn8441154 for the relevant transcription factor (e.g. https://www.synapse.org/#!Synapse:syn8442103 for the TF MAX). - - (Optional) The validation labels for the challenge's evaluation cell type from https://www.synapse.org/#!Synapse:syn8442975 for the relevant transcription factor (generally primary liver cells, e.g. https://www.synapse.org/#!Synapse:syn8443021 for the TF MAX). + - The training labels from https://www.synapse.org/#!Synapse:syn7413983 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn7415202 for the TF MAX). + - The validation labels from https://www.synapse.org/#!Synapse:syn8441154 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn8442103 for the TF MAX). + - (Optional) The validation labels for the challenge's evaluation cell type (liver) from https://www.synapse.org/#!Synapse:syn8442975 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn8443021 for the TF MAX). -5. Run `write_label_bigwig.py` +5. Run `prep_metadata_labels.py`. diff --git a/dataset_preprocessing/encode-tfbs/prep_accessibility.py b/dataset_preprocessing/encode-tfbs/prep_accessibility.py index 141981c0..86644f6c 100644 --- a/dataset_preprocessing/encode-tfbs/prep_accessibility.py +++ b/dataset_preprocessing/encode-tfbs/prep_accessibility.py @@ -1,9 +1,132 @@ +# Adapted from https://github.com/GuanLab/Leopard/blob/master/data/quantile_normalize_bigwig.py + import argparse, time import numpy as np import pyBigWig -# Human chromosome names -chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX'] +# Human chromosomes in hg19 +chrom_sizes = {'chr1': 249250621, 'chr10': 135534747, 'chr11': 135006516, 'chr12': 133851895, 'chr13': 115169878, 'chr14': 107349540, 'chr15': 102531392, 'chr16': 90354753, 'chr17': 81195210, 'chr18': 78077248, 'chr19': 59128983, 'chr2': 243199373, 'chr20': 63025520, 'chr21': 48129895, 'chr22': 51304566, 'chr3': 198022430, 'chr4': 191154276, 'chr5': 180915260, 'chr6': 171115067, 'chr7': 159138663, 'chr8': 146364022, 'chr9': 141213431, 'chrX': 155270560} + + +def qn_sample_to_array( + input_celltypes, + subsampling_ratio=1000, + data_pfx = '/users/abalsubr/wilds/examples/data/encode-tfbs_v1.0/' +): + itime = time.time() + # chromosome-specific subsampling seeds + chr_to_seed = {} + i = 0 + for the_chr in chrom_sizes: + chr_to_seed[the_chr] = i + i += 1 + + # subsampling; multiple replicates are added + sample_len = np.ceil(np.array(list(chrom_sizes.values()))/subsampling_ratio).astype(int) + sample = np.zeros(sum(sample_len)) + start = 0 + j = 0 + for the_chr in chrom_sizes: + np.random.seed(chr_to_seed[the_chr]) + for ct in input_celltypes: + path = data_pfx + 'DNASE.{}.fc.signal.bigwig'.format(ct) + bw = pyBigWig.open(path) + signal = np.nan_to_num(np.array(bw.values(the_chr, 0, chrom_sizes[the_chr]))) + index = np.random.randint(0, len(signal), sample_len[j]) + sample[start:(start+sample_len[j])] += (1.0/len(input_celltypes))*signal[index] + start += sample_len[j] + j += 1 + print(the_chr, ct, time.time() - itime) + + if np.any(np.isnan(sample)): + print('wtf! sample contains nan!') + sample.sort() + np.save(data_pfx + "qn.{}.npy".format('.'.join(input_celltypes)), sample) + + +# quantile normalization via numpy inter/extra-polation +def anchor(input_data, sample, ref): # input 1d array + sample.sort() + ref.sort() + # 0. create the mapping function + index = np.array(np.where(np.diff(sample) != 0)) + 1 + index = index.flatten() + x = np.concatenate((np.zeros(1), sample[index])) # domain + y = np.zeros(len(x)) # codomain + for i in np.arange(0,len(index)-1, 1): + start = index[i] + end = index[i+1] + y[i+1] = np.mean(ref[start:end]) + i += 1 + start = index[i] + end = len(ref) + y[i+1] = np.mean(ref[start:end]) + # 1. interpolate + output = np.interp(input_data, x, y) + # 2. extrapolate + degree = 1 # degree of the fitting polynomial + num = 10 # number of positions for extrapolate + f1 = np.poly1d(np.polyfit(sample[-num:],ref[-num:],degree)) +# f2=np.poly1d(np.polyfit(sample[:num],ref[:num],degree)) + output[input_data > sample[-1]] = f1(input_data[input_data > sample[-1]]) +# output[input_data {}'.format(_unsorted_dir, _sorted_dir)) - - mdf_posamb = pd.read_csv( - _sorted_dir, - sep='\t', header=None, index_col=None, names=['chr', 'start', 'stop', 'y', 'celltype'] - ) - chromsizes_list = [(k, v) for k, v in chrom_sizes.items()] - for ct in _all_celltypes: - ct_labels_bw_path = _data_dir + "labels/MAX/MAX_{}.bigwig".format(ct) - df = mdf_posamb[mdf_posamb['celltype'] == ct] - bw = pyBigWig.open(ct_labels_bw_path, "w") - bw.addHeader(chromsizes_list) - bw.addEntries(list(df['chr']), list(df['start']), ends=list(df['start']+50), values=list(df['y'])) - print(ct, time.time() - itime) - bw.close() diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py index 484632a0..b637e762 100644 --- a/examples/models/CNN_genome.py +++ b/examples/models/CNN_genome.py @@ -60,7 +60,7 @@ def __init__(self, num_tasks=16, n_channels_in=5): def forward(self, x): # input_size = 12800 # input_channels = 5 - conv1 = self.dconv_down1(x) # Out: (input_size) x 15 + conv1 = self.dconv_down1(x) # Output size: (input_size) x 15 x = self.maxpool(conv1) # (input_size / 2) x 15 conv2 = self.dconv_down2(x) # (input_size / 2) x 22 @@ -101,10 +101,14 @@ def forward(self, x): x = self.dconv_up1(x) # (input_size) x 15 - x = self.conv_last(x) + x = self.conv_last(x) # (input_size/50 - 3) x 1 + x = torch.squeeze(x) + # Default input_size == 12800: x has size N x 1 x 253 at this point. if self.d_out is None: - return x.shape[-1] # Default: 253 values + self.d_out = x.shape[-1] + out = x else: # middle 128 values - out = x[:, :, 64:192] - return torch.squeeze(out) + out = x[:, 64:192] + + return out diff --git a/examples/models/initializer.py b/examples/models/initializer.py index 4effde23..86033ffe 100644 --- a/examples/models/initializer.py +++ b/examples/models/initializer.py @@ -75,11 +75,11 @@ def initialize_model(config, d_out, is_featurizer=False): model = nn.Linear(out_features=d_out, **config.model_kwargs) elif config.model == 'unet-seq': if is_featurizer: - featurizer = UNet(out_features=None, **config.model_kwargs) + featurizer = UNet(num_tasks=None, **config.model_kwargs) classifier = nn.Linear(featurizer.d_out, d_out) model = (featurizer, classifier) else: - model = UNet(out_features=d_out, **config.model_kwargs) + model = UNet(num_tasks=d_out, **config.model_kwargs) else: raise ValueError(f'Model: {config.model} not recognized.') return model diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index f6fe29cb..a09303e8 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -25,13 +25,13 @@ class EncodeTFBSDataset(WILDSDataset): Each sequence is annotated with the celltype of origin (a string) and the chromosome of origin (a string). Website: - https://www.synapse.org/#!Synapse:syn6131484 . This is the website for the challenge; the data can be downloaded from here into the meta + https://www.synapse.org/#!Synapse:syn6131484 . This is the website for the challenge; the data can be downloaded from here as per the instructions in data """ _dataset_name = 'encode-tfbs' _versions_dict = { '1.0': { - 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x7efd626149d648f699d9e686d7aa81a9/contents/blob/', + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x6ba433262726430eb91449a1edb2fed0/contents/blob/', 'compressed_size': None}} def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): @@ -183,7 +183,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' 'val': 'Validation (OOD)', 'test': 'Test', } - # Add new split scheme specifying custom test and val celltypes in the format test..val.. + # Add new split scheme specifying custom test and val celltypes in the format val..test., e.g. 'official' is 'val.A549.test.GM12878' elif '.' in self._split_scheme: all_celltypes = train_celltypes + val_celltype + test_celltype in_val_ct = self._split_scheme.split('.')[1] @@ -242,7 +242,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' # For the OOD splits (val and test), we subsample by a factor of 3 # For the id_val split if it exists, we subsample by a factor of 15 for subsample_seed, (split, subsample_factor) in enumerate( - [('val', 3), ('test', 3), ('id_val', 15)]): + [('val', 3), ('test', 3), ('id_val', 3*)]): if split not in self._split_dict: continue split_mask = (self._split_array == self._split_dict[split]) split_idxs = np.arange(len(self._split_array))[split_mask] @@ -272,8 +272,24 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' # Set up file handles for DNase features self._dnase_allcelltypes = {} for ct in self._all_celltypes: - dnase_bw_path = os.path.join(self._data_dir, 'DNase/{}.bigwig'.format(ct)) + """ + if 'challenge' in self._split_scheme: + dnase_bw_path = os.path.join(self._data_dir, 'DNASE.{}.fc.signal.bigwig'.format(ct)) + else: + dnase_bw_path = os.path.join(self._data_dir, 'DNase/{}.bigwig'.format(ct)) + """ + dnase_bw_path = os.path.join(self._data_dir, 'DNASE.{}.fc.signal.bigwig'.format(ct)) self._dnase_allcelltypes[ct] = pyBigWig.open(dnase_bw_path) + + # Load subsampled DNase arrays for normalization purposes + self._dnase_qnorm_arrays = {} + for ct in self._all_celltypes: + qnorm_arr_path = os.path.join(self._data_dir, 'qn.{}.npy'.format(ct)) + self._dnase_qnorm_arrays[ct] = np.load(qnorm_arr_path) + self._norm_ref_distr = np.zeros(len(self._dnase_qnorm_arrays[ct])) + train_cts = splits['train']['celltypes'] + for ct in train_cts: + self._norm_ref_distr += (1.0/len(train_cts))*self._dnase_qnorm_arrays[ct] # Set up metadata fields, map, array self._metadata_fields = ['chr', 'celltype'] @@ -296,6 +312,59 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' super().__init__(root_dir, download, split_scheme) + # quantile normalization via numpy inter/extra-polation + def anchor(input_data, sample, ref): # input 1d array + sample.sort() + ref.sort() + # 0. create the mapping function + index = np.array(np.where(np.diff(sample) != 0)) + 1 + index = index.flatten() + x = np.concatenate((np.zeros(1), sample[index])) # domain + y = np.zeros(len(x)) # codomain + for i in np.arange(0,len(index)-1, 1): + start = index[i] + end = index[i+1] + y[i+1] = np.mean(ref[start:end]) + i += 1 + start = index[i] + end = len(ref) + y[i+1] = np.mean(ref[start:end]) + # 1. interpolate + output = np.interp(input_data, x, y) + # 2. extrapolate + degree = 1 # degree of the fitting polynomial + num = 10 # number of positions for extrapolate + f1 = np.poly1d(np.polyfit(sample[-num:],ref[-num:],degree)) + # f2=np.poly1d(np.polyfit(sample[:num],ref[:num],degree)) + output[input_data > sample[-1]] = f1(input_data[input_data > sample[-1]]) + # output[input_data Date: Thu, 1 Apr 2021 18:13:04 -0700 Subject: [PATCH 126/244] fixes in get_input --- .../encode-tfbs/prep_accessibility.py | 2 +- wilds/datasets/encodetfbs_dataset.py | 76 ++++++++++--------- 2 files changed, 42 insertions(+), 36 deletions(-) diff --git a/dataset_preprocessing/encode-tfbs/prep_accessibility.py b/dataset_preprocessing/encode-tfbs/prep_accessibility.py index 86644f6c..1b15f30e 100644 --- a/dataset_preprocessing/encode-tfbs/prep_accessibility.py +++ b/dataset_preprocessing/encode-tfbs/prep_accessibility.py @@ -90,7 +90,7 @@ def wrap_anchor( ends = np.concatenate(([starts[0]],ends)) starts = np.concatenate(([0],starts)) vals = np.concatenate(([0],vals)) - if ends[-1] != chrom_sizes[the_chr]: + if ends[-1] != len_signal: starts = np.concatenate((starts,[ends[-1]])) ends = np.concatenate((ends,[len_signal])) vals = np.concatenate((vals,[0])) diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index a09303e8..b130fb7f 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -8,6 +8,36 @@ from wilds.common.grouper import CombinatorialGrouper from wilds.common.metrics.all_metrics import MultiTaskAveragePrecision + +# quantile normalization via numpy inter/extra-polation +def anchor(input_data, sample, ref): # input 1d array + sample.sort() + ref.sort() + # 0. create the mapping function + index = np.array(np.where(np.diff(sample) != 0)) + 1 + index = index.flatten() + x = np.concatenate((np.zeros(1), sample[index])) # domain + y = np.zeros(len(x)) # codomain + for i in np.arange(0,len(index)-1, 1): + start = index[i] + end = index[i+1] + y[i+1] = np.mean(ref[start:end]) + i += 1 + start = index[i] + end = len(ref) + y[i+1] = np.mean(ref[start:end]) + # 1. interpolate + output = np.interp(input_data, x, y) + # 2. extrapolate + degree = 1 # degree of the fitting polynomial + num = 10 # number of positions for extrapolate + f1 = np.poly1d(np.polyfit(sample[-num:],ref[-num:],degree)) +# f2=np.poly1d(np.polyfit(sample[:num],ref[:num],degree)) + output[input_data > sample[-1]] = f1(input_data[input_data > sample[-1]]) +# output[input_data Date: Thu, 1 Apr 2021 18:56:26 -0700 Subject: [PATCH 127/244] integration fixes --- .../encode-tfbs/prep_accessibility.py | 29 ++++++++++++------- wilds/datasets/encodetfbs_dataset.py | 17 +++++++---- 2 files changed, 31 insertions(+), 15 deletions(-) diff --git a/dataset_preprocessing/encode-tfbs/prep_accessibility.py b/dataset_preprocessing/encode-tfbs/prep_accessibility.py index 1b15f30e..2c36a124 100644 --- a/dataset_preprocessing/encode-tfbs/prep_accessibility.py +++ b/dataset_preprocessing/encode-tfbs/prep_accessibility.py @@ -75,8 +75,8 @@ def anchor(input_data, sample, ref): # input 1d array def wrap_anchor( signal, - sample, ref, - len_signal + sample, + ref ): ## 1.format as bigwig first x = signal @@ -90,9 +90,9 @@ def wrap_anchor( ends = np.concatenate(([starts[0]],ends)) starts = np.concatenate(([0],starts)) vals = np.concatenate(([0],vals)) - if ends[-1] != len_signal: + if ends[-1] != len(signal): starts = np.concatenate((starts,[ends[-1]])) - ends = np.concatenate((ends,[len_signal])) + ends = np.concatenate((ends,[len(signal)])) vals = np.concatenate((vals,[0])) ## 2.then quantile normalization @@ -103,12 +103,14 @@ def wrap_anchor( def dnase_normalize( input_bw_celltype, sample_celltype, - ref_celltype, + ref_celltypes, data_pfx = '/users/abalsubr/wilds/examples/data/encode-tfbs_v1.0/' ): itime = time.time() sample = np.load(data_pfx + "qn.{}.npy".format(sample_celltype)) - ref = np.load(data_pfx + "qn.{}.npy".format(ref_celltype)) + ref = np.zeros(len(sample)) + for ct in ref_celltypes: + ref += (1.0/len(ref_celltypes))*np.load(data_pfx + "qn.{}.npy".format(ct)) chromsizes_list = [(k, v) for k, v in chrom_sizes.items()] bw_output = pyBigWig.open(data_pfx + 'DNase.{}.norm.bigwig'.format(input_bw_celltype), 'w') @@ -120,7 +122,7 @@ def dnase_normalize( bw = pyBigWig.open(data_pfx + 'DNASE.{}.fc.signal.bigwig'.format(input_bw_celltype)) signal += np.nan_to_num(np.array(bw.values(the_chr, 0, chrom_sizes[the_chr]))) bw.close() - vals_anchored, starts, ends = wrap_anchor(signal, sample, ref, chrom_sizes[the_chr]) + vals_anchored, starts, ends = wrap_anchor(signal, sample, ref) # write normalized dnase file. chroms = np.array([the_chr] * len(vals_anchored)) bw_output.addEntries(chroms, starts, ends=ends, values=vals_anchored) @@ -154,11 +156,18 @@ def generate_accessibility_archives(input_dir='dnase_bigwigs', output_dir='codal if __name__ == '__main__': + ch_train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'K562', 'A549', 'GM12878'] + ch_val_celltype = ['HepG2'] + ch_test_celltype = ['liver'] + ref_celltypes = ch_train_celltypes + ch_val_celltype + all_celltypes = ref_celltypes + ch_test_celltype + for ct in all_celltypes: + qn_sample_to_array([ct]) + for ct in all_celltypes: + dnase_normalize(ct, ct, ref_celltypes) # parser = argparse.ArgumentParser() # parser.add_argument('--input_dir', required=True) # parser.add_argument('--output_dir', required=True) # args = parser.parse_args() - generate_accessibility_archives( - input_dir=args.input_dir, - output_dir=args.output_dir) \ No newline at end of file + # generate_accessibility_archives(input_dir=args.input_dir, output_dir=args.output_dir) \ No newline at end of file diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index b130fb7f..af0785d2 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -309,6 +309,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' dnase_bw_path = os.path.join(self._data_dir, 'DNase/{}.bigwig'.format(ct)) """ dnase_bw_path = os.path.join(self._data_dir, 'DNASE.{}.fc.signal.bigwig'.format(ct)) + # dnase_bw_path = os.path.join(self._data_dir, 'DNase.{}.norm.bigwig'.format(ct)) self._dnase_allcelltypes[ct] = pyBigWig.open(dnase_bw_path) # Load subsampled DNase arrays for normalization purposes @@ -317,9 +318,11 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' qnorm_arr_path = os.path.join(self._data_dir, 'qn.{}.npy'.format(ct)) self._dnase_qnorm_arrays[ct] = np.load(qnorm_arr_path) self._norm_ref_distr = np.zeros(len(self._dnase_qnorm_arrays[ct])) - train_cts = splits['train']['celltypes'] - for ct in train_cts: - self._norm_ref_distr += (1.0/len(train_cts))*self._dnase_qnorm_arrays[ct] + test_cts = splits['test']['celltypes'] + num_to_avg = len(self._all_celltypes) - len(test_cts) + for ct in self._all_celltypes: + if ct not in test_cts: + self._norm_ref_distr += (1.0/num_to_avg)*self._dnase_qnorm_arrays[ct] # Set up metadata fields, map, array self._metadata_fields = ['chr', 'celltype'] @@ -369,7 +372,7 @@ def norm_signal( vals_arr = np.zeros(ends[-1]) for i in range(len(starts)): vals_arr[starts[i]:ends[i]] = vals_anchored[i] - return vals_arr + return vals_arr.astype(float) def get_input(self, idx, window_size=12800): """ @@ -392,7 +395,11 @@ def get_input(self, idx, window_size=12800): print("error", chrom, interval_start, interval_end) assert(np.isnan(seq_this).sum() == 0) assert(np.isnan(dnase_this).sum() == 0) - dnase_this = self.norm_signal(dnase_this, this_metadata['celltype']) + # print(dnase_this.dtype) +# try: +# dnase_this = self.norm_signal(dnase_this, this_metadata['celltype']) +# except RuntimeError: +# print(dnase_this.dtype) # print('a', dnase_this.shape, starts, ends, starts.shape, ends.shape) return torch.tensor(np.column_stack( [seq_this, From 5770ce234a89ec14c87a387b16423d76a5019693 Mon Sep 17 00:00:00 2001 From: aikanor Date: Thu, 1 Apr 2021 21:05:14 -0700 Subject: [PATCH 128/244] working version, with pre-normalized codalab --- wilds/datasets/encodetfbs_dataset.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index af0785d2..990ed306 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -49,7 +49,7 @@ class EncodeTFBSDataset(WILDSDataset): Label (y): y is a 128-bit vector, with each element y_i indicating the binding status of a 200bp window. It is 1 if this 200bp region is bound by the transcription factor, and 0 otherwise, for i = 0,1,...,127. - Suppose the input window x starts at coordinate sc, extending until coordinate (sc+12800). Then y_i is the label of the window starting at coordinate (sc+3200)+(50*i). + Concretely, suppose the input window x starts at coordinate sc, extending until coordinate (sc+12800). Then y_i is the label of the window starting at coordinate (sc+3200)+(50*i). Metadata: Each sequence is annotated with the celltype of origin (a string) and the chromosome of origin (a string). @@ -61,7 +61,8 @@ class EncodeTFBSDataset(WILDSDataset): _dataset_name = 'encode-tfbs' _versions_dict = { '1.0': { - 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x6ba433262726430eb91449a1edb2fed0/contents/blob/', + # 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x6ba433262726430eb91449a1edb2fed0/contents/blob/', + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x94505beac8794afbb5b2dbae747ec29f/contents/blob/', 'compressed_size': None}} def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): @@ -308,8 +309,8 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' else: dnase_bw_path = os.path.join(self._data_dir, 'DNase/{}.bigwig'.format(ct)) """ - dnase_bw_path = os.path.join(self._data_dir, 'DNASE.{}.fc.signal.bigwig'.format(ct)) - # dnase_bw_path = os.path.join(self._data_dir, 'DNase.{}.norm.bigwig'.format(ct)) + # dnase_bw_path = os.path.join(self._data_dir, 'DNASE.{}.fc.signal.bigwig'.format(ct)) + dnase_bw_path = os.path.join(self._data_dir, 'DNase.{}.norm.bigwig'.format(ct)) self._dnase_allcelltypes[ct] = pyBigWig.open(dnase_bw_path) # Load subsampled DNase arrays for normalization purposes From c499ffa7a684d0a1e872e1077f4b894493c08289 Mon Sep 17 00:00:00 2001 From: Etienne DAVID Date: Fri, 2 Apr 2021 20:38:11 +0200 Subject: [PATCH 129/244] correct metric --- .gitignore | 5 +++-- examples/run_expt.py | 2 ++ setup.py | 1 + wilds/common/data_loaders.py | 2 ++ wilds/common/metrics/all_metrics.py | 10 ++++++---- 5 files changed, 14 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index 388819f7..02035664 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ wilds.egg-info data logs test_faster -paper +paper* .vscode -*sh \ No newline at end of file +*sh +*ipynb \ No newline at end of file diff --git a/examples/run_expt.py b/examples/run_expt.py index e8039113..34edbd66 100644 --- a/examples/run_expt.py +++ b/examples/run_expt.py @@ -184,6 +184,7 @@ def main(): transform=transform) if split == 'train': + datasets[split]['loader'] = get_train_loader( loader=config.train_loader, dataset=datasets[split]['dataset'], @@ -193,6 +194,7 @@ def main(): distinct_groups=config.distinct_groups, n_groups_per_batch=config.n_groups_per_batch, **config.loader_kwargs) + else: datasets[split]['loader'] = get_eval_loader( loader=config.eval_loader, diff --git a/setup.py b/setup.py index 9cd1f596..1fb6a2e5 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ 'tqdm>=4.53.0', 'outdated>=0.2.0', 'pytz>=2020.4', + 'torchvision==0.8.2' ], license='MIT', packages=setuptools.find_packages(exclude=['dataset_preprocessing', 'examples', 'examples.models', 'examples.models.bert']), diff --git a/wilds/common/data_loaders.py b/wilds/common/data_loaders.py index b806832a..12e7e0c3 100644 --- a/wilds/common/data_loaders.py +++ b/wilds/common/data_loaders.py @@ -63,6 +63,7 @@ def get_train_loader(loader, dataset, batch_size, raise ValueError(f'n_groups_per_batch was set to {n_groups_per_batch} but there are only {grouper.n_groups} groups specified.') group_ids = grouper.metadata_to_group(dataset.metadata_array) + batch_sampler = GroupSampler( group_ids=group_ids, batch_size=batch_size, @@ -70,6 +71,7 @@ def get_train_loader(loader, dataset, batch_size, uniform_over_groups=uniform_over_groups, distinct_groups=distinct_groups) + return DataLoader(dataset, shuffle=None, sampler=None, diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py index 84e592d3..16172033 100644 --- a/wilds/common/metrics/all_metrics.py +++ b/wilds/common/metrics/all_metrics.py @@ -190,13 +190,13 @@ def _compute_element_wise(self, y_pred ,y_true ): #target_scores = F.softmax(target_logits, dim=1)[..., 0] pred_boxes = target_boxes[target_scores > self.score_threshold] - det_accuracy = self._accuracy(src_boxes["boxes"],pred_boxes,iou_threshold=self.iou_threshold) + det_accuracy = torch.mean(torch.stack([ self._accuracy(src_boxes["boxes"],pred_boxes,iou_thr) for iou_thr in np.arange(0.5,0.76,0.05)])) batch_results.append(det_accuracy) return torch.tensor(batch_results) - def _accuracy(self, src_boxes,pred_boxes , iou_threshold = 1.): + def _accuracy(self, src_boxes,pred_boxes , iou_threshold): total_gt = len(src_boxes) total_pred = len(pred_boxes) @@ -206,8 +206,8 @@ def _accuracy(self, src_boxes,pred_boxes , iou_threshold = 1.): # Define the matcher and distance matrix based on iou matcher = Matcher(iou_threshold,iou_threshold,allow_low_quality_matches=False) - src_boxes = box_convert(src_boxes , "cxcywh" ,"xyxy") - pred_boxes = box_convert(pred_boxes , "cxcywh" ,"xyxy") + #src_boxes = box_convert(src_boxes , "cxcywh" ,"xyxy") + #pred_boxes = box_convert(pred_boxes , "cxcywh" ,"xyxy") match_quality_matrix = box_iou(src_boxes,pred_boxes) @@ -220,6 +220,8 @@ def _accuracy(self, src_boxes,pred_boxes , iou_threshold = 1.): #in Matcher, a pred element can be matched only twice false_positive = torch.count_nonzero(results == -1) + ( len(matched_elements) - len(matched_elements.unique())) false_negative = total_gt - true_positive + acc= true_positive / ( true_positive + false_positive + false_negative ) + return true_positive / ( true_positive + false_positive + false_negative ) From f72e0dc68c190106293ffb44a16101ba8f738e50 Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Sat, 3 Apr 2021 03:13:56 -0700 Subject: [PATCH 130/244] Add documentation for evaluation script --- README.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/README.md b/README.md index 00955e91..7328e7c3 100644 --- a/README.md +++ b/README.md @@ -113,6 +113,17 @@ While the `camelyon17` dataset is small and fast to train on, we advise against The image datasets (`iwildcam`, `camelyon17`, `fmow`, and `poverty`) tend to have high disk I/O usage. If training time is much slower for you than the approximate times listed above, consider checking if I/O is a bottleneck (e.g., by moving to a local disk if you are using a network drive, or by increasing the number of data loader workers). To speed up training, you could also disable evaluation at each epoch or for all splits by toggling `--evaluate_all_splits` and related arguments. +We also provide an evaluation script to evaluate prediction CSV files. In order to evaluate your predictions, run: + +```bash +python examples/evaluate.py --root-dir +``` + +where `` is the path to your predictions directory, `` is where the results JSON will be +outputted and `` is the dataset directory. The predictions directory should have a subdirectory for each dataset +(e.g. `iwildcam`) containing prediction CSV files to evaluate. The evaluation script will skip over any datasets that has +missing prediction files. Any dataset not in `` will be downloaded to ``. + We have an [executable version](https://wilds.stanford.edu/codalab) of our paper on CodaLab that contains the exact commands, code, and data for the experiments reported in our paper, which rely on these scripts. Trained model weights for all datasets can also be found there. From 904559cd85931a2cdc2ccfa50c5c1acc9173d300 Mon Sep 17 00:00:00 2001 From: kohpangwei Date: Sun, 4 Apr 2021 23:20:40 -0700 Subject: [PATCH 131/244] Update README.md --- README.md | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 7328e7c3..b5d879f2 100644 --- a/README.md +++ b/README.md @@ -113,7 +113,7 @@ While the `camelyon17` dataset is small and fast to train on, we advise against The image datasets (`iwildcam`, `camelyon17`, `fmow`, and `poverty`) tend to have high disk I/O usage. If training time is much slower for you than the approximate times listed above, consider checking if I/O is a bottleneck (e.g., by moving to a local disk if you are using a network drive, or by increasing the number of data loader workers). To speed up training, you could also disable evaluation at each epoch or for all splits by toggling `--evaluate_all_splits` and related arguments. -We also provide an evaluation script to evaluate prediction CSV files. In order to evaluate your predictions, run: +We also provide an evaluation script that aggregates prediction CSV files for different replicates and reports on their combined evaluation. To use this, run: ```bash python examples/evaluate.py --root-dir @@ -121,8 +121,7 @@ python examples/evaluate.py --root-dir where `` is the path to your predictions directory, `` is where the results JSON will be outputted and `` is the dataset directory. The predictions directory should have a subdirectory for each dataset -(e.g. `iwildcam`) containing prediction CSV files to evaluate. The evaluation script will skip over any datasets that has -missing prediction files. Any dataset not in `` will be downloaded to ``. +(e.g. `iwildcam`) containing prediction CSV files to evaluate; see our [submission guidelines](https://wilds.stanford.edu/submit/) for the format. The evaluation script will skip over any datasets that has missing prediction files. Any dataset not in `` will be downloaded to ``. We have an [executable version](https://wilds.stanford.edu/codalab) of our paper on CodaLab that contains the exact commands, code, and data for the experiments reported in our paper, which rely on these scripts. Trained model weights for all datasets can also be found there. From 1636590b9b113be7e0de2b77af77320bc3021f12 Mon Sep 17 00:00:00 2001 From: kohpangwei Date: Mon, 5 Apr 2021 21:03:19 -0700 Subject: [PATCH 132/244] Update metric.py --- wilds/common/metrics/metric.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wilds/common/metrics/metric.py b/wilds/common/metrics/metric.py index 8f5ee416..89582577 100644 --- a/wilds/common/metrics/metric.py +++ b/wilds/common/metrics/metric.py @@ -131,8 +131,8 @@ def _compute_group_wise(self, y_pred, y_true, g, n_groups): else: group_metrics.append( self._compute( - get_subset_from_mask(y_pred, g == group_idx), - get_subset_from_mask(y_true, g == group_idx))) + y_pred[g == group_idx], + y_true[g == group_idx])) group_metrics = torch.stack(group_metrics) worst_group_metric = self.worst(group_metrics[group_counts>0]) From 96205552c6dea4172034e4f9992fd07fa3c49f8e Mon Sep 17 00:00:00 2001 From: Berton Earnshaw Date: Tue, 6 Apr 2021 08:34:04 -0600 Subject: [PATCH 133/244] Add correct target_resolution to dataset config --- examples/configs/datasets.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 4f5f74b7..cf2e6e48 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -288,6 +288,7 @@ 'model_kwargs': {'pretrained': True}, 'train_transform': 'image_base', 'eval_transform': 'image_base', + 'target_resolution': (256, 256), 'loss_function': 'cross_entropy', 'groupby_fields': ['experiment'], 'val_metric': 'acc_avg', From 232a0e3444ed3d071c93da6544d70925cef08061 Mon Sep 17 00:00:00 2001 From: Berton Earnshaw Date: Tue, 6 Apr 2021 13:42:10 -0600 Subject: [PATCH 134/244] Add validation set and cell_type to metadata --- wilds/datasets/rxrx1_dataset.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/wilds/datasets/rxrx1_dataset.py b/wilds/datasets/rxrx1_dataset.py index 6558d8e7..f14503f2 100644 --- a/wilds/datasets/rxrx1_dataset.py +++ b/wilds/datasets/rxrx1_dataset.py @@ -1,4 +1,3 @@ -from datetime import datetime import os from pathlib import Path @@ -9,7 +8,7 @@ from wilds.datasets.wilds_dataset import WILDSDataset from wilds.common.grouper import CombinatorialGrouper -from wilds.common.metrics.all_metrics import Accuracy, Recall, F1 +from wilds.common.metrics.all_metrics import Accuracy class RxRx1Dataset(WILDSDataset): @@ -72,12 +71,9 @@ def __init__(self, version=None, root_dir='rxrx1-wilds', download=False, df = pd.read_csv(self._data_dir / 'metadata.csv') # Splits - # FIXME: Add validation - self._split_dict = {'train': 0, 'test': 1} - self._split_names = {'train': 'Train', 'test': 'Test'} + self._split_dict = {'train': 0, 'val': 1, 'test': 2} + self._split_names = {'train': 'Train', 'val': 'Validation', 'test': 'Test'} self._split_array = df.dataset.apply(self._split_dict.get).values - # split_dict = {'train': 0, 'test': 1} - # self._split_array = df.dataset.apply(split_dict.get).values # Filenames def create_filepath(row): @@ -97,7 +93,7 @@ def create_filepath(row): # Convert experiment and well from strings to idxs indexed_metadata = {} self._metadata_map = {} - for key in ['experiment', 'well']: + for key in ['cell_type', 'experiment', 'well']: all_values = list(df[key].unique()) value_to_idx_map = {value: idx for idx, value in enumerate(all_values)} value_idxs = [value_to_idx_map[value] for value in df[key].tolist()] @@ -105,18 +101,19 @@ def create_filepath(row): indexed_metadata[key] = value_idxs self._metadata_array = torch.tensor( - np.stack([indexed_metadata['experiment'], + np.stack([indexed_metadata['cell_type'], + indexed_metadata['experiment'], df['plate'].values, indexed_metadata['well'], df['site'].values, self.y_array], axis=1) ) - self._metadata_fields = ['experiment', 'plate', 'well', 'site', 'y'] + self._metadata_fields = ['cell_type', 'experiment', 'plate', 'well', 'site', 'y'] # eval grouper self._eval_grouper = CombinatorialGrouper( dataset=self, - groupby_fields=(['experiment']) + groupby_fields=(['cell_type', 'experiment']) ) super().__init__(root_dir, download, split_scheme) From 812050ed57be6530b311bb19b0d1b8462cec474c Mon Sep 17 00:00:00 2001 From: Berton Earnshaw Date: Tue, 6 Apr 2021 14:03:23 -0600 Subject: [PATCH 135/244] Update download_url --- wilds/datasets/rxrx1_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wilds/datasets/rxrx1_dataset.py b/wilds/datasets/rxrx1_dataset.py index f14503f2..c7468462 100644 --- a/wilds/datasets/rxrx1_dataset.py +++ b/wilds/datasets/rxrx1_dataset.py @@ -52,7 +52,7 @@ class RxRx1Dataset(WILDSDataset): _dataset_name = 'rxrx1' _versions_dict = { '1.0': { - 'download_url': 'https://worksheets.codalab.org/rest/bundles/0xc01e117bb4504f988700408eaeeb16a8/contents/blob/', + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x6b7a05a3056a434498f0bb1252eb8440/contents/blob/', 'compressed_size': 7_413_123_845} } From afcfc6bf2d614b531b83e4d7099419aebc4f7f2c Mon Sep 17 00:00:00 2001 From: aikanor Date: Tue, 6 Apr 2021 18:41:10 -0700 Subject: [PATCH 136/244] fix DNase normalization --- .../encode-tfbs/prep_accessibility.py | 36 +++++++++++-------- examples/configs/datasets.py | 6 ++-- wilds/datasets/encodetfbs_dataset.py | 34 ++++++++++-------- 3 files changed, 43 insertions(+), 33 deletions(-) diff --git a/dataset_preprocessing/encode-tfbs/prep_accessibility.py b/dataset_preprocessing/encode-tfbs/prep_accessibility.py index 2c36a124..a2be3d91 100644 --- a/dataset_preprocessing/encode-tfbs/prep_accessibility.py +++ b/dataset_preprocessing/encode-tfbs/prep_accessibility.py @@ -10,28 +10,32 @@ def qn_sample_to_array( input_celltypes, + input_chroms=None, subsampling_ratio=1000, data_pfx = '/users/abalsubr/wilds/examples/data/encode-tfbs_v1.0/' ): itime = time.time() + if input_chroms is None: + input_chroms = chrom_sizes.keys() + qn_chrom_sizes = { k: chrom_sizes[k] for k in input_chroms } # chromosome-specific subsampling seeds chr_to_seed = {} i = 0 - for the_chr in chrom_sizes: + for the_chr in qn_chrom_sizes: chr_to_seed[the_chr] = i i += 1 # subsampling; multiple replicates are added - sample_len = np.ceil(np.array(list(chrom_sizes.values()))/subsampling_ratio).astype(int) + sample_len = np.ceil(np.array(list(qn_chrom_sizes.values()))/subsampling_ratio).astype(int) sample = np.zeros(sum(sample_len)) start = 0 j = 0 - for the_chr in chrom_sizes: + for the_chr in qn_chrom_sizes: np.random.seed(chr_to_seed[the_chr]) for ct in input_celltypes: path = data_pfx + 'DNASE.{}.fc.signal.bigwig'.format(ct) bw = pyBigWig.open(path) - signal = np.nan_to_num(np.array(bw.values(the_chr, 0, chrom_sizes[the_chr]))) + signal = np.nan_to_num(np.array(bw.values(the_chr, 0, qn_chrom_sizes[the_chr]))) index = np.random.randint(0, len(signal), sample_len[j]) sample[start:(start+sample_len[j])] += (1.0/len(input_celltypes))*signal[index] start += sample_len[j] @@ -104,6 +108,7 @@ def dnase_normalize( input_bw_celltype, sample_celltype, ref_celltypes, + out_fname = 'norm', data_pfx = '/users/abalsubr/wilds/examples/data/encode-tfbs_v1.0/' ): itime = time.time() @@ -113,7 +118,8 @@ def dnase_normalize( ref += (1.0/len(ref_celltypes))*np.load(data_pfx + "qn.{}.npy".format(ct)) chromsizes_list = [(k, v) for k, v in chrom_sizes.items()] - bw_output = pyBigWig.open(data_pfx + 'DNase.{}.norm.bigwig'.format(input_bw_celltype), 'w') + bw_output = pyBigWig.open(data_pfx + 'DNase.{}.{}.bigwig'.format( + input_bw_celltype, out_fname), 'w') bw_output.addHeader(chromsizes_list) # bw_output.addHeader(list(zip(chr_all , num_bp)), maxZooms=0) # zip two turples @@ -126,7 +132,7 @@ def dnase_normalize( # write normalized dnase file. chroms = np.array([the_chr] * len(vals_anchored)) bw_output.addEntries(chroms, starts, ends=ends, values=vals_anchored) - print(the_chr, time.time() - itime) + print(input_bw_celltype, the_chr, time.time() - itime) bw_output.close() @@ -156,18 +162,18 @@ def generate_accessibility_archives(input_dir='dnase_bigwigs', output_dir='codal if __name__ == '__main__': + train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] ch_train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'K562', 'A549', 'GM12878'] ch_val_celltype = ['HepG2'] ch_test_celltype = ['liver'] - ref_celltypes = ch_train_celltypes + ch_val_celltype - all_celltypes = ref_celltypes + ch_test_celltype + ref_celltypes = ch_train_celltypes + all_celltypes = ch_train_celltypes + ch_val_celltype + ch_test_celltype for ct in all_celltypes: - qn_sample_to_array([ct]) + qn_sample_to_array([ct], input_chroms=train_chroms) + + # Create normalized bigwigs for OOD validation split. for ct in all_celltypes: dnase_normalize(ct, ct, ref_celltypes) -# parser = argparse.ArgumentParser() -# parser.add_argument('--input_dir', required=True) -# parser.add_argument('--output_dir', required=True) -# args = parser.parse_args() - - # generate_accessibility_archives(input_dir=args.input_dir, output_dir=args.output_dir) \ No newline at end of file + # Create normalized bigwig for ID validation split. + for ct in ch_test_celltype: + dnase_normalize(ct, ct, ch_test_celltype) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 7cc44ab7..cb9f9421 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -119,9 +119,9 @@ 'optimizer': 'Adam', 'scheduler': None, 'batch_size': 128, - 'lr': 1e-4, - 'weight_decay': 1e-2, - 'n_epochs': 1, + 'lr': 1e-5, + 'weight_decay': 1e-4, + 'n_epochs': 10, 'n_groups_per_batch': 2, 'algo_log_metric': 'multitask_binary_accuracy', # 'irm_lambda': 1.0, diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 990ed306..31a842db 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -61,8 +61,7 @@ class EncodeTFBSDataset(WILDSDataset): _dataset_name = 'encode-tfbs' _versions_dict = { '1.0': { - # 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x6ba433262726430eb91449a1edb2fed0/contents/blob/', - 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x94505beac8794afbb5b2dbae747ec29f/contents/blob/', + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x370346dfdda44758b75041ca1a5921f4/contents/blob/', 'compressed_size': None}} def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): @@ -86,7 +85,9 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' # This typically happens at the flanking regions of peaks. # For our purposes, we will ignore these ambiguous labels during training and eval. self.y_array[self.y_array == 0.5] = float('nan') - + + dnase_norm_mode = 'norm' + # Construct splits train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] val_chroms = ['chr2', 'chr9', 'chr11'] @@ -152,10 +153,15 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' 'test': 'Test', } # Add challenge splits, assuming 'liver' celltype is in the data. - elif self._split_scheme == 'challenge': - ch_train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'K562', 'A549', 'GM12878'] - ch_val_celltype = ['HepG2'] - ch_test_celltype = ['liver'] + elif self._split_scheme in ['challenge', 'challenge_alt']: + if self._split_scheme == 'challenge': + ch_train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'K562', 'A549', 'GM12878'] + ch_val_celltype = ['HepG2'] + ch_test_celltype = ['liver'] + elif self._split_scheme == 'challenge_alt': + ch_train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'A549', 'GM12878'] + ch_val_celltype = ['K562'] + ch_test_celltype = ['liver'] splits = { 'train': { 'chroms': train_chroms, @@ -187,8 +193,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' 'id_val': 'Validation (ID)', } elif self._split_scheme == 'challenge_in-dist': - ch_train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'K562', 'A549', 'GM12878'] - ch_val_celltype = ['HepG2'] + dnase_norm_mode = 'norm_id' ch_test_celltype = ['liver'] splits = { 'train': { @@ -309,8 +314,10 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' else: dnase_bw_path = os.path.join(self._data_dir, 'DNase/{}.bigwig'.format(ct)) """ - # dnase_bw_path = os.path.join(self._data_dir, 'DNASE.{}.fc.signal.bigwig'.format(ct)) - dnase_bw_path = os.path.join(self._data_dir, 'DNase.{}.norm.bigwig'.format(ct)) + dnase_bw_path = os.path.join( + self._data_dir, + 'DNase.{}.{}.bigwig'.format(ct, dnase_norm_mode) + ) self._dnase_allcelltypes[ct] = pyBigWig.open(dnase_bw_path) # Load subsampled DNase arrays for normalization purposes @@ -351,10 +358,8 @@ def norm_signal( signal, sample_celltype ): - ## 1.format as bigwig first x = signal z = np.concatenate(([0],x,[0])) # pad two zeroes - # find boundary starts = np.where(np.diff(z) != 0)[0] ends = starts[1:] starts = starts[:-1] @@ -367,8 +372,7 @@ def norm_signal( starts = np.concatenate((starts,[ends[-1]])) ends = np.concatenate((ends,[len(signal)])) vals = np.concatenate((vals,[0])) - - ## 2.then quantile normalization + vals_anchored = anchor(vals, self._dnase_qnorm_arrays[sample_celltype], self._norm_ref_distr) vals_arr = np.zeros(ends[-1]) for i in range(len(starts)): From 7ecc464736d7e0295408fc3e80df732a199d47a7 Mon Sep 17 00:00:00 2001 From: aikanor Date: Tue, 6 Apr 2021 20:09:26 -0700 Subject: [PATCH 137/244] fix featurizer dim init bug --- dataset_preprocessing/encode-tfbs/prep_accessibility.py | 2 +- examples/models/CNN_genome.py | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/dataset_preprocessing/encode-tfbs/prep_accessibility.py b/dataset_preprocessing/encode-tfbs/prep_accessibility.py index a2be3d91..10e282e5 100644 --- a/dataset_preprocessing/encode-tfbs/prep_accessibility.py +++ b/dataset_preprocessing/encode-tfbs/prep_accessibility.py @@ -176,4 +176,4 @@ def generate_accessibility_archives(input_dir='dnase_bigwigs', output_dir='codal dnase_normalize(ct, ct, ref_celltypes) # Create normalized bigwig for ID validation split. for ct in ch_test_celltype: - dnase_normalize(ct, ct, ch_test_celltype) + dnase_normalize(ct, ct, ch_test_celltype, out_fname = 'norm_id') diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py index b637e762..a10404d6 100644 --- a/examples/models/CNN_genome.py +++ b/examples/models/CNN_genome.py @@ -28,8 +28,6 @@ class UNet(nn.Module): def __init__(self, num_tasks=16, n_channels_in=5): super().__init__() - self.d_out = num_tasks - self.dconv_down1 = double_conv(n_channels_in, 15) self.dconv_down2 = double_conv(15, 22) self.dconv_down3 = double_conv(22, 33) @@ -55,6 +53,7 @@ def __init__(self, num_tasks=16, n_channels_in=5): self.upsamp_1 = nn.ConvTranspose1d(15, 15, 2, stride=2) self.conv_last = nn.Conv1d(15, 1, 200, stride=50, padding=0) + self.d_out = num_tasks if num_tasks is not None else 253 def forward(self, x): @@ -105,8 +104,7 @@ def forward(self, x): x = torch.squeeze(x) # Default input_size == 12800: x has size N x 1 x 253 at this point. - if self.d_out is None: - self.d_out = x.shape[-1] + if self.d_out == 253: out = x else: # middle 128 values out = x[:, 64:192] From 60bd3247d4d9ffe7a3d8ea351ad951b81f8d0471 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Tue, 6 Apr 2021 21:18:36 -0700 Subject: [PATCH 138/244] move_to fix for molpcba --- examples/utils.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/examples/utils.py b/examples/utils.py index c554c38e..67a9358c 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -213,18 +213,17 @@ def get_model_prefix(dataset, config): f"{dataset_name}_{replicate_str}_") return prefix -# Adapted from https://discuss.pytorch.org/t/pytorch-tensor-to-device-for-a-list-of-dict/66283 def move_to(obj, device): - if torch.is_tensor(obj): - return obj.to(device) - elif isinstance(obj, dict): + if isinstance(obj, dict): return {k: move_to(v, device) for k, v in obj.items()} elif isinstance(obj, list): return [move_to(v, device) for v in obj] elif isinstance(obj, float) or isinstance(obj, int): return obj else: - raise TypeError("Invalid type for move_to") + # Assume obj is a Tensor or other type + # (like Batch, for MolPCBA) that supports .to(device) + return obj.to(device) def detach_and_clone(obj): if torch.is_tensor(obj): From 15b4ef7457414e6d84efc078bb66e77a4c0f9375 Mon Sep 17 00:00:00 2001 From: aikanor Date: Wed, 7 Apr 2021 00:02:51 -0700 Subject: [PATCH 139/244] refactor preprocessing --- dataset_preprocessing/encode-tfbs/prep_accessibility.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/dataset_preprocessing/encode-tfbs/prep_accessibility.py b/dataset_preprocessing/encode-tfbs/prep_accessibility.py index 10e282e5..2d8ca789 100644 --- a/dataset_preprocessing/encode-tfbs/prep_accessibility.py +++ b/dataset_preprocessing/encode-tfbs/prep_accessibility.py @@ -106,13 +106,12 @@ def wrap_anchor( def dnase_normalize( input_bw_celltype, - sample_celltype, ref_celltypes, out_fname = 'norm', data_pfx = '/users/abalsubr/wilds/examples/data/encode-tfbs_v1.0/' ): itime = time.time() - sample = np.load(data_pfx + "qn.{}.npy".format(sample_celltype)) + sample = np.load(data_pfx + "qn.{}.npy".format(input_bw_celltype)) ref = np.zeros(len(sample)) for ct in ref_celltypes: ref += (1.0/len(ref_celltypes))*np.load(data_pfx + "qn.{}.npy".format(ct)) @@ -173,7 +172,7 @@ def generate_accessibility_archives(input_dir='dnase_bigwigs', output_dir='codal # Create normalized bigwigs for OOD validation split. for ct in all_celltypes: - dnase_normalize(ct, ct, ref_celltypes) + dnase_normalize(ct, ref_celltypes) # Create normalized bigwig for ID validation split. for ct in ch_test_celltype: - dnase_normalize(ct, ct, ch_test_celltype, out_fname = 'norm_id') + dnase_normalize(ct, ch_test_celltype, out_fname = 'norm_id') From 09c6f6276ac729a049f4bc72c6ef9df29043f3a0 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Wed, 7 Apr 2021 17:01:14 -0700 Subject: [PATCH 140/244] increasing batch size to 64 --- examples/configs/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index cf2e6e48..cc9db912 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -297,7 +297,7 @@ 'optimizer': 'Adam', 'optimizer_kwargs': {}, 'scheduler': None, # TODO cosine with warmup from transformers - 'batch_size': 32, #1400, + 'batch_size': 64, #1400, 'lr': 1e-3, 'weight_decay': 1e-5, 'n_epochs': 60, From 932eff999ebdbdb703530933b226b21c3fe79f33 Mon Sep 17 00:00:00 2001 From: aikanor Date: Wed, 7 Apr 2021 18:54:50 -0700 Subject: [PATCH 141/244] change bundle metadata_df --- wilds/datasets/encodetfbs_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 31a842db..f94cac1b 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -61,7 +61,7 @@ class EncodeTFBSDataset(WILDSDataset): _dataset_name = 'encode-tfbs' _versions_dict = { '1.0': { - 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x370346dfdda44758b75041ca1a5921f4/contents/blob/', + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x3ae54c5d74524d9b8c45b7cc6c84091c/contents/blob/', 'compressed_size': None}} def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): @@ -274,7 +274,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' allzeroes_mask = (self._y_array.sum(axis=1) == 0).numpy() keep_mask = keep_mask & ~(train_mask & allzeroes_mask) - # Subsample the testing and validation indices + # Subsample the testing and validation indices, to speed up evaluation. # For the OOD splits (val and test), we subsample by a factor of 3 # For the id_val split if it exists, we subsample by a factor of 15 for subsample_seed, (split, subsample_factor) in enumerate( From 0c39bdf259c0c57261fe94210cf3717562e240b4 Mon Sep 17 00:00:00 2001 From: aikanor Date: Wed, 7 Apr 2021 21:42:50 -0700 Subject: [PATCH 142/244] change bundle --- wilds/datasets/encodetfbs_dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index f94cac1b..6b3e9656 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -61,7 +61,7 @@ class EncodeTFBSDataset(WILDSDataset): _dataset_name = 'encode-tfbs' _versions_dict = { '1.0': { - 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x3ae54c5d74524d9b8c45b7cc6c84091c/contents/blob/', + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x777a583b83c24e209d54e56c2dfdaa06/contents/blob/', 'compressed_size': None}} def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): @@ -128,6 +128,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' 'id_val': 'Validation (ID)', } elif self._split_scheme == 'in-dist': + dnase_norm_mode = 'norm_id' splits = { 'train': { 'chroms': train_chroms, From b18988880152319226bf7bfb280a4cf8347d5d4a Mon Sep 17 00:00:00 2001 From: aikanor Date: Thu, 8 Apr 2021 16:24:37 -0700 Subject: [PATCH 143/244] refactor challenge split as default --- wilds/datasets/encodetfbs_dataset.py | 149 +++++---------------------- 1 file changed, 27 insertions(+), 122 deletions(-) diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 6b3e9656..fd691b81 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -9,34 +9,6 @@ from wilds.common.metrics.all_metrics import MultiTaskAveragePrecision -# quantile normalization via numpy inter/extra-polation -def anchor(input_data, sample, ref): # input 1d array - sample.sort() - ref.sort() - # 0. create the mapping function - index = np.array(np.where(np.diff(sample) != 0)) + 1 - index = index.flatten() - x = np.concatenate((np.zeros(1), sample[index])) # domain - y = np.zeros(len(x)) # codomain - for i in np.arange(0,len(index)-1, 1): - start = index[i] - end = index[i+1] - y[i+1] = np.mean(ref[start:end]) - i += 1 - start = index[i] - end = len(ref) - y[i+1] = np.mean(ref[start:end]) - # 1. interpolate - output = np.interp(input_data, x, y) - # 2. extrapolate - degree = 1 # degree of the fitting polynomial - num = 10 # number of positions for extrapolate - f1 = np.poly1d(np.polyfit(sample[-num:],ref[-num:],degree)) -# f2=np.poly1d(np.polyfit(sample[:num],ref[:num],degree)) - output[input_data > sample[-1]] = f1(input_data[input_data > sample[-1]]) -# output[input_data.test., e.g. 'official' is 'val.A549.test.GM12878' elif '.' in self._split_scheme: all_celltypes = train_celltypes + val_celltype + test_celltype @@ -354,31 +259,31 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' super().__init__(root_dir, download, split_scheme) - def norm_signal( - self, - signal, - sample_celltype - ): - x = signal - z = np.concatenate(([0],x,[0])) # pad two zeroes - starts = np.where(np.diff(z) != 0)[0] - ends = starts[1:] - starts = starts[:-1] - vals = x[starts] - if starts[0] != 0: - ends = np.concatenate(([starts[0]],ends)) - starts = np.concatenate(([0],starts)) - vals = np.concatenate(([0],vals)) - if ends[-1] != len(signal): - starts = np.concatenate((starts,[ends[-1]])) - ends = np.concatenate((ends,[len(signal)])) - vals = np.concatenate((vals,[0])) +# def norm_signal( +# self, +# signal, +# sample_celltype +# ): +# x = signal +# z = np.concatenate(([0],x,[0])) # pad two zeroes +# starts = np.where(np.diff(z) != 0)[0] +# ends = starts[1:] +# starts = starts[:-1] +# vals = x[starts] +# if starts[0] != 0: +# ends = np.concatenate(([starts[0]],ends)) +# starts = np.concatenate(([0],starts)) +# vals = np.concatenate(([0],vals)) +# if ends[-1] != len(signal): +# starts = np.concatenate((starts,[ends[-1]])) +# ends = np.concatenate((ends,[len(signal)])) +# vals = np.concatenate((vals,[0])) - vals_anchored = anchor(vals, self._dnase_qnorm_arrays[sample_celltype], self._norm_ref_distr) - vals_arr = np.zeros(ends[-1]) - for i in range(len(starts)): - vals_arr[starts[i]:ends[i]] = vals_anchored[i] - return vals_arr.astype(float) +# vals_anchored = anchor(vals, self._dnase_qnorm_arrays[sample_celltype], self._norm_ref_distr) +# vals_arr = np.zeros(ends[-1]) +# for i in range(len(starts)): +# vals_arr[starts[i]:ends[i]] = vals_anchored[i] +# return vals_arr.astype(float) def get_input(self, idx, window_size=12800): """ From f07d96080afb594dbbd5d1d16c0b0f643fc3e3a0 Mon Sep 17 00:00:00 2001 From: aikanor Date: Thu, 8 Apr 2021 17:56:50 -0700 Subject: [PATCH 144/244] add id_test split --- wilds/datasets/encodetfbs_dataset.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index fd691b81..6dffbb76 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -86,18 +86,24 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' 'chroms': test_chroms, 'celltypes': test_celltype }, + 'id_test': { + 'chroms': test_chroms, + 'celltypes': train_celltypes + } } self._split_dict = { 'train': 0, 'val': 1, 'test': 2, 'id_val': 3, + 'id_test': 4 } self._split_names = { 'train': 'Train', 'val': 'Validation (OOD)', 'test': 'Test', 'id_val': 'Validation (ID)', + 'id_test': 'Test (ID)', } elif self._split_scheme == 'in-dist': dnase_norm_mode = 'norm_id' From a20be41a96f9fcf8f9be8196b41cfe04a255cbee Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Fri, 9 Apr 2021 09:40:43 -0700 Subject: [PATCH 145/244] multistepLR scheduler and new optim defaults for encode --- examples/configs/datasets.py | 9 +++++---- examples/configs/scheduler.py | 5 +++++ examples/configs/supported.py | 2 +- examples/scheduler.py | 8 ++++++-- wilds/datasets/encodetfbs_dataset.py | 24 ++++++++++++------------ 5 files changed, 29 insertions(+), 19 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index cb9f9421..5168333f 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -117,14 +117,15 @@ 'val_metric': 'avgprec-macro_all', 'val_metric_decreasing': False, 'optimizer': 'Adam', - 'scheduler': None, + 'scheduler': 'MultiStepLR', + 'scheduler_kwargs': {'milestones':[3,6], 'gamma': 0.1}, 'batch_size': 128, - 'lr': 1e-5, + 'lr': 1e-3, 'weight_decay': 1e-4, - 'n_epochs': 10, + 'n_epochs': 9, 'n_groups_per_batch': 2, 'algo_log_metric': 'multitask_binary_accuracy', - # 'irm_lambda': 1.0, + 'irm_lambda': 100.0, # 'coral_penalty_weight': 0.1, }, 'fmow': { diff --git a/examples/configs/scheduler.py b/examples/configs/scheduler.py index 4ae29b10..e342d156 100644 --- a/examples/configs/scheduler.py +++ b/examples/configs/scheduler.py @@ -12,4 +12,9 @@ 'step_size': 1, } }, + 'MultiStepLR': { + 'scheduler_kwargs':{ + 'gamma': 0.1, + } + }, } diff --git a/examples/configs/supported.py b/examples/configs/supported.py index a49732a8..35c44741 100644 --- a/examples/configs/supported.py +++ b/examples/configs/supported.py @@ -36,4 +36,4 @@ 'gin-virtual', 'logistic_regression', 'code-gpt-py', 'unet-seq'] algorithms = ['ERM', 'groupDRO', 'deepCORAL', 'IRM'] optimizers = ['SGD', 'Adam', 'AdamW'] -schedulers = ['linear_schedule_with_warmup', 'ReduceLROnPlateau', 'StepLR'] +schedulers = ['linear_schedule_with_warmup', 'ReduceLROnPlateau', 'StepLR', 'MultiStepLR'] diff --git a/examples/scheduler.py b/examples/scheduler.py index 7b966624..6fc3b3d2 100644 --- a/examples/scheduler.py +++ b/examples/scheduler.py @@ -1,5 +1,5 @@ from transformers import get_linear_schedule_with_warmup -from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR +from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR, MultiStepLR def initialize_scheduler(config, optimizer, n_train_steps): # construct schedulers @@ -23,9 +23,13 @@ def initialize_scheduler(config, optimizer, n_train_steps): scheduler = StepLR(optimizer, **config.scheduler_kwargs) step_every_batch = False use_metric = False + elif config.scheduler == 'MultiStepLR': + scheduler = MultiStepLR(optimizer, **config.scheduler_kwargs) + step_every_batch = False + use_metric = False else: raise ValueError('Scheduler not recognized.') - # add an step_every_batch field + # add a step_every_batch field scheduler.step_every_batch = step_every_batch scheduler.use_metric = use_metric return scheduler diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 6b3e9656..b8a5b0ab 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -47,8 +47,8 @@ class EncodeTFBSDataset(WILDSDataset): 12800-base-pair regions of sequence with a quantified chromatin accessibility readout. Label (y): - y is a 128-bit vector, with each element y_i indicating the binding status of a 200bp window. It is 1 if this 200bp region is bound by the transcription factor, and 0 otherwise, for i = 0,1,...,127. - + y is a 128-bit vector, with each element y_i indicating the binding status of a 200bp window. It is 1 if this 200bp region is bound by the transcription factor, and 0 otherwise, for i = 0,1,...,127. + Concretely, suppose the input window x starts at coordinate sc, extending until coordinate (sc+12800). Then y_i is the label of the window starting at coordinate (sc+3200)+(50*i). Metadata: @@ -85,9 +85,9 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' # This typically happens at the flanking regions of peaks. # For our purposes, we will ignore these ambiguous labels during training and eval. self.y_array[self.y_array == 0.5] = float('nan') - + dnase_norm_mode = 'norm' - + # Construct splits train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] val_chroms = ['chr2', 'chr9', 'chr11'] @@ -220,7 +220,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' 'val': 'Validation (OOD)', 'test': 'Test', } - # Add new split scheme specifying custom test and val celltypes in the format val..test., e.g. 'official' is 'val.A549.test.GM12878' + # Add new split scheme specifying custom test and val celltypes in the format val..test., e.g. 'official' is 'val.A549.test.GM12878' elif '.' in self._split_scheme: all_celltypes = train_celltypes + val_celltype + test_celltype in_val_ct = self._split_scheme.split('.')[1] @@ -275,7 +275,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' allzeroes_mask = (self._y_array.sum(axis=1) == 0).numpy() keep_mask = keep_mask & ~(train_mask & allzeroes_mask) - # Subsample the testing and validation indices, to speed up evaluation. + # Subsample the testing and validation indices, to speed up evaluation. # For the OOD splits (val and test), we subsample by a factor of 3 # For the id_val split if it exists, we subsample by a factor of 15 for subsample_seed, (split, subsample_factor) in enumerate( @@ -316,11 +316,11 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' dnase_bw_path = os.path.join(self._data_dir, 'DNase/{}.bigwig'.format(ct)) """ dnase_bw_path = os.path.join( - self._data_dir, + self._data_dir, 'DNase.{}.{}.bigwig'.format(ct, dnase_norm_mode) ) self._dnase_allcelltypes[ct] = pyBigWig.open(dnase_bw_path) - + # Load subsampled DNase arrays for normalization purposes self._dnase_qnorm_arrays = {} for ct in self._all_celltypes: @@ -355,8 +355,8 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' super().__init__(root_dir, download, split_scheme) def norm_signal( - self, - signal, + self, + signal, sample_celltype ): x = signal @@ -373,13 +373,13 @@ def norm_signal( starts = np.concatenate((starts,[ends[-1]])) ends = np.concatenate((ends,[len(signal)])) vals = np.concatenate((vals,[0])) - + vals_anchored = anchor(vals, self._dnase_qnorm_arrays[sample_celltype], self._norm_ref_distr) vals_arr = np.zeros(ends[-1]) for i in range(len(starts)): vals_arr[starts[i]:ends[i]] = vals_anchored[i] return vals_arr.astype(float) - + def get_input(self, idx, window_size=12800): """ Returns x for a given idx in metadata_array, which has been filtered to only take windows with the desired stride. From da05aae64dc792c80f538c607952ff455d98474c Mon Sep 17 00:00:00 2001 From: aikanor Date: Wed, 14 Apr 2021 18:01:12 -0700 Subject: [PATCH 146/244] add buggy version of dynamic normalization --- dataset_preprocessing/encode-tfbs/README.md | 11 +- examples/models/CNN_genome.py | 2 +- wilds/datasets/encodetfbs_dataset.py | 140 ++++++++++++++++++-- 3 files changed, 134 insertions(+), 19 deletions(-) diff --git a/dataset_preprocessing/encode-tfbs/README.md b/dataset_preprocessing/encode-tfbs/README.md index b4f806ec..e0e39cc1 100644 --- a/dataset_preprocessing/encode-tfbs/README.md +++ b/dataset_preprocessing/encode-tfbs/README.md @@ -9,11 +9,12 @@ 2. Run `python prep_sequence.py --seq_path SEQUENCE_PATH --output_dir OUTPUT_DIR` to write the fasta file found in `SEQUENCE_PATH` to a numpy array archive in `OUTPUT_DIR`. -3. Download the DNase accessibility data. This consists of whole-genome DNase files in bigwig format from https://guanfiles.dcmb.med.umich.edu/Leopard/dnase_bigwig/. +3. Download the DNase accessibility data. This consists of whole-genome DNase files in bigwig format from https://guanfiles.dcmb.med.umich.edu/Leopard/dnase_bigwig/. These are saved with filename `DNASE..fc.signal.bigwig`. -4. Download the labels from the challenge into a label directory created for this purpose: - - The training labels from https://www.synapse.org/#!Synapse:syn7413983 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn7415202 for the TF MAX). - - The validation labels from https://www.synapse.org/#!Synapse:syn8441154 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn8442103 for the TF MAX). - - (Optional) The validation labels for the challenge's evaluation cell type (liver) from https://www.synapse.org/#!Synapse:syn8442975 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn8443021 for the TF MAX). +4. Download the labels from the challenge into a label directory `labels/` created for this purpose: + - The training chromosome labels for the challenge's training cell types from https://www.synapse.org/#!Synapse:syn7413983 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn7415202 for the TF MAX, , downloaded as MAX.train.labels.tsv.gz ). + - The training chromosome labels for the challenge's evaluation cell type (liver) from https://www.synapse.org/#!Synapse:syn8077511 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn8077648 for the TF MAX, downloaded as MAX.train_wc.labels.tsv.gz ). + - The validation chromosome labels for the challenge's training cell types from https://www.synapse.org/#!Synapse:syn8441154 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn8442103 for the TF MAX, downloaded as MAX.val.labels.tsv.gz ). + - The validation chromosome labels for the challenge's evaluation cell type (liver) from https://www.synapse.org/#!Synapse:syn8442975 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn8443021 for the TF MAX, downloaded as MAX.test.labels.tsv.gz ). 5. Run `prep_metadata_labels.py`. diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py index a10404d6..7c8a5be8 100644 --- a/examples/models/CNN_genome.py +++ b/examples/models/CNN_genome.py @@ -38,7 +38,7 @@ def __init__(self, num_tasks=16, n_channels_in=5): self.maxpool = nn.MaxPool1d(2) # self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) - self.conv_middle = single_conv(109, 109) + # self.conv_middle = single_conv(109, 109) self.upsamp_6 = nn.ConvTranspose1d(109, 109, 2, stride=2) self.dconv_up5 = double_conv(73 + 109, 73) diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 486adeed..2bb6c6a3 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -9,6 +9,34 @@ from wilds.common.metrics.all_metrics import MultiTaskAveragePrecision +# quantile normalization via numpy inter/extra-polation +def anchor(input_data, sample, ref): # input 1d array + sample.sort() + ref.sort() + # 0. create the mapping function + index = np.array(np.where(np.diff(sample) != 0)) + 1 + index = index.flatten() + x = np.concatenate((np.zeros(1), sample[index])) # domain + y = np.zeros(len(x)) # codomain + for i in np.arange(0,len(index)-1, 1): + start = index[i] + end = index[i+1] + y[i+1] = np.mean(ref[start:end]) + i += 1 + start = index[i] + end = len(ref) + y[i+1] = np.mean(ref[start:end]) + # 1. interpolate + output = np.interp(input_data, x, y) + # 2. extrapolate + degree = 1 # degree of the fitting polynomial + num = 10 # number of positions for extrapolate + f1 = np.poly1d(np.polyfit(sample[-num:],ref[-num:],degree)) +# f2=np.poly1d(np.polyfit(sample[:num],ref[:num],degree)) + output[input_data > sample[-1]] = f1(input_data[input_data > sample[-1]]) +# output[input_data.test., e.g. 'official' is 'val.HepG2.test.liver' + elif '.' in self._split_scheme: + all_celltypes = train_celltypes + val_celltype + test_celltype + in_val_ct = self._split_scheme.split('.')[1] + in_test_ct = self._split_scheme.split('.')[3] + train_celltypes = [ct for ct in all_celltypes if ((ct != in_val_ct) and (ct != in_test_ct))] + val_celltype = [in_val_ct] + test_celltype = [in_test_ct] + splits = { + 'train': { + 'chroms': train_chroms, + 'celltypes': train_celltypes + }, + 'id_val': { + 'chroms': val_chroms, + 'celltypes': train_celltypes + }, + 'val': { + 'chroms': val_chroms, + 'celltypes': val_celltype + }, + 'test': { + 'chroms': test_chroms, + 'celltypes': test_celltype + }, + } + self._split_dict = { + 'train': 0, + 'val': 1, + 'test': 2, + 'id_val': 3, + } + self._split_names = { + 'train': 'Train', + 'val': 'Validation (OOD)', + 'test': 'Test', + 'id_val': 'Validation (ID)', + } else: raise ValueError(f'Split scheme {self._split_scheme} not recognized') @@ -182,16 +275,8 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' # Set up file handles for DNase features self._dnase_allcelltypes = {} for ct in self._all_celltypes: - """ - if 'challenge' in self._split_scheme: - dnase_bw_path = os.path.join(self._data_dir, 'DNASE.{}.fc.signal.bigwig'.format(ct)) - else: - dnase_bw_path = os.path.join(self._data_dir, 'DNase/{}.bigwig'.format(ct)) - """ - dnase_bw_path = os.path.join( - self._data_dir, - 'DNase.{}.{}.bigwig'.format(ct, dnase_norm_mode) - ) + dnase_bw_path = os.path.join(self._data_dir, 'DNASE.{}.fc.signal.bigwig'.format(ct)) + # dnase_bw_path = os.path.join(self._data_dir, 'DNase.{}.{}.bigwig'.format(ct, dnase_norm_mode)) self._dnase_allcelltypes[ct] = pyBigWig.open(dnase_bw_path) # Load subsampled DNase arrays for normalization purposes @@ -226,7 +311,33 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' self._metric = MultiTaskAveragePrecision() super().__init__(root_dir, download, split_scheme) - + + def norm_signal( + self, + signal, + sample_celltype + ): + x = signal + z = np.concatenate(([0],x,[0])) # pad two zeroes + starts = np.where(np.diff(z) != 0)[0] + ends = starts[1:] + starts = starts[:-1] + vals = x[starts] + if starts[0] != 0: + ends = np.concatenate(([starts[0]],ends)) + starts = np.concatenate(([0],starts)) + vals = np.concatenate(([0],vals)) + if ends[-1] != len(signal): + starts = np.concatenate((starts,[ends[-1]])) + ends = np.concatenate((ends,[len(signal)])) + vals = np.concatenate((vals,[0])) + + vals_anchored = anchor(vals, self._dnase_qnorm_arrays[sample_celltype], self._norm_ref_distr) + vals_arr = np.zeros(ends[-1]) + for i in range(len(starts)): + vals_arr[starts[i]:ends[i]] = vals_anchored[i] + return vals_arr.astype(float) + def get_input(self, idx, window_size=12800): """ Returns x for a given idx in metadata_array, which has been filtered to only take windows with the desired stride. @@ -243,7 +354,10 @@ def get_input(self, idx, window_size=12800): seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end] dnase_bw = self._dnase_allcelltypes[this_metadata['celltype']] dnase_this = np.nan_to_num(dnase_bw.values(chrom, interval_start, interval_end, numpy=True)) - + + assert(np.isnan(seq_this).sum() == 0) + assert(np.isnan(dnase_this).sum() == 0) + dnase_this = self.norm_signal(dnase_this, this_metadata['celltype']) return torch.tensor(np.column_stack( [seq_this, dnase_this] From a20ba550643bbdf31ecb49bb918aa8b6cedfc668 Mon Sep 17 00:00:00 2001 From: aikanor Date: Wed, 14 Apr 2021 18:05:15 -0700 Subject: [PATCH 147/244] change bundle --- wilds/datasets/encodetfbs_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 2bb6c6a3..b1d572cb 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -61,7 +61,7 @@ class EncodeTFBSDataset(WILDSDataset): _dataset_name = 'encode-tfbs' _versions_dict = { '1.0': { - 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x777a583b83c24e209d54e56c2dfdaa06/contents/blob/', + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0xf0a83ce649c540b39149250dc8e3c66b/contents/blob/', 'compressed_size': None}} def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): From b1a3643d8bf1ec5d429606e8f1984ed5e93f8789 Mon Sep 17 00:00:00 2001 From: aikanor Date: Wed, 14 Apr 2021 18:14:43 -0700 Subject: [PATCH 148/244] fix float-double type error --- examples/models/CNN_genome.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py index 7c8a5be8..a5711a48 100644 --- a/examples/models/CNN_genome.py +++ b/examples/models/CNN_genome.py @@ -59,6 +59,7 @@ def __init__(self, num_tasks=16, n_channels_in=5): def forward(self, x): # input_size = 12800 # input_channels = 5 + x = x.float() conv1 = self.dconv_down1(x) # Output size: (input_size) x 15 x = self.maxpool(conv1) # (input_size / 2) x 15 From 84033d79dcee120d7b85847255f13bcbed691e3f Mon Sep 17 00:00:00 2001 From: aikanor Date: Thu, 15 Apr 2021 13:31:33 -0700 Subject: [PATCH 149/244] add workaround for dynamic normalization --- wilds/datasets/encodetfbs_dataset.py | 106 ++++++++++++++++++--------- 1 file changed, 73 insertions(+), 33 deletions(-) diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index b1d572cb..c4fee9cf 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -8,6 +8,8 @@ from wilds.common.grouper import CombinatorialGrouper from wilds.common.metrics.all_metrics import MultiTaskAveragePrecision +# Human chromosomes in hg19 +chrom_sizes = {'chr1': 249250621, 'chr10': 135534747, 'chr11': 135006516, 'chr12': 133851895, 'chr13': 115169878, 'chr14': 107349540, 'chr15': 102531392, 'chr16': 90354753, 'chr17': 81195210, 'chr18': 78077248, 'chr19': 59128983, 'chr2': 243199373, 'chr20': 63025520, 'chr21': 48129895, 'chr22': 51304566, 'chr3': 198022430, 'chr4': 191154276, 'chr5': 180915260, 'chr6': 171115067, 'chr7': 159138663, 'chr8': 146364022, 'chr9': 141213431, 'chrX': 155270560} # quantile normalization via numpy inter/extra-polation def anchor(input_data, sample, ref): # input 1d array @@ -38,6 +40,67 @@ def anchor(input_data, sample, ref): # input 1d array return output +def wrap_anchor( + signal, + sample, + ref +): + ## 1.format as bigwig first + x = signal + z = np.concatenate(([0],x,[0])) # pad two zeroes + # find boundary + starts = np.where(np.diff(z) != 0)[0] + ends = starts[1:] + starts = starts[:-1] + vals = x[starts] + if starts[0] != 0: + ends = np.concatenate(([starts[0]],ends)) + starts = np.concatenate(([0],starts)) + vals = np.concatenate(([0],vals)) + if ends[-1] != len(signal): + starts = np.concatenate((starts,[ends[-1]])) + ends = np.concatenate((ends,[len(signal)])) + vals = np.concatenate((vals,[0])) + + ## 2.then quantile normalization + vals_anchored = anchor(vals, sample, ref) + return vals_anchored, starts, ends + + +def dnase_normalize( + input_bw_celltype, + ref_celltypes, + out_fname = 'norm', + data_pfx = '/users/abalsubr/wilds/examples/data/encode-tfbs_v1.0/' +): + if not data_pfx.endswith('/'): + data_pfx = data_pfx + '/' + itime = time.time() + sample = np.load(data_pfx + "qn.{}.npy".format(input_bw_celltype)) + ref = np.zeros(len(sample)) + for ct in ref_celltypes: + ref += (1.0/len(ref_celltypes))*np.load(data_pfx + "qn.{}.npy".format(ct)) + + chromsizes_list = [(k, v) for k, v in chrom_sizes.items()] + out_fname = data_pfx + 'DNase.{}.{}.bigwig'.format(input_bw_celltype, out_fname) + bw_output = pyBigWig.open(out_fname, 'w') + bw_output.addHeader(chromsizes_list) + # bw_output.addHeader(list(zip(chr_all , num_bp)), maxZooms=0) # zip two turples + + for the_chr in chrom_sizes: + signal = np.zeros(chrom_sizes[the_chr]) + bw = pyBigWig.open(data_pfx + 'DNASE.{}.fc.signal.bigwig'.format(input_bw_celltype)) + signal += np.nan_to_num(np.array(bw.values(the_chr, 0, chrom_sizes[the_chr]))) + bw.close() + vals_anchored, starts, ends = wrap_anchor(signal, sample, ref) + # write normalized dnase file. + chroms = np.array([the_chr] * len(vals_anchored)) + bw_output.addEntries(chroms, starts, ends=ends, values=vals_anchored) + print(input_bw_celltype, the_chr, time.time() - itime) + + bw_output.close() + + class EncodeTFBSDataset(WILDSDataset): """ ENCODE-DREAM-wilds dataset of transcription factor binding sites. @@ -271,12 +334,15 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' self._seq_bp[chrom] = seq_arr[chrom] print(chrom, time.time() - itime) del seq_arr - - # Set up file handles for DNase features + + # Set up file handles for DNase features, writing normalized DNase tracks along the way. self._dnase_allcelltypes = {} for ct in self._all_celltypes: - dnase_bw_path = os.path.join(self._data_dir, 'DNASE.{}.fc.signal.bigwig'.format(ct)) - # dnase_bw_path = os.path.join(self._data_dir, 'DNase.{}.{}.bigwig'.format(ct, dnase_norm_mode)) + orig_dnase_bw_path = os.path.join(self._data_dir, 'DNASE.{}.fc.signal.bigwig'.format(ct)) + dnase_bw_path = os.path.join(self._data_dir, 'DNase.{}.{}.bigwig'.format(ct, self._split_scheme)) + if not os.path.exists(dnase_bw_path): + ref_celltypes = splits['train']['celltypes'] + dnase_normalize(ct, ref_celltypes, out_fname=self._split_scheme, data_pfx=self._data_dir) self._dnase_allcelltypes[ct] = pyBigWig.open(dnase_bw_path) # Load subsampled DNase arrays for normalization purposes @@ -312,32 +378,6 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' super().__init__(root_dir, download, split_scheme) - def norm_signal( - self, - signal, - sample_celltype - ): - x = signal - z = np.concatenate(([0],x,[0])) # pad two zeroes - starts = np.where(np.diff(z) != 0)[0] - ends = starts[1:] - starts = starts[:-1] - vals = x[starts] - if starts[0] != 0: - ends = np.concatenate(([starts[0]],ends)) - starts = np.concatenate(([0],starts)) - vals = np.concatenate(([0],vals)) - if ends[-1] != len(signal): - starts = np.concatenate((starts,[ends[-1]])) - ends = np.concatenate((ends,[len(signal)])) - vals = np.concatenate((vals,[0])) - - vals_anchored = anchor(vals, self._dnase_qnorm_arrays[sample_celltype], self._norm_ref_distr) - vals_arr = np.zeros(ends[-1]) - for i in range(len(starts)): - vals_arr[starts[i]:ends[i]] = vals_anchored[i] - return vals_arr.astype(float) - def get_input(self, idx, window_size=12800): """ Returns x for a given idx in metadata_array, which has been filtered to only take windows with the desired stride. @@ -355,9 +395,9 @@ def get_input(self, idx, window_size=12800): dnase_bw = self._dnase_allcelltypes[this_metadata['celltype']] dnase_this = np.nan_to_num(dnase_bw.values(chrom, interval_start, interval_end, numpy=True)) - assert(np.isnan(seq_this).sum() == 0) - assert(np.isnan(dnase_this).sum() == 0) - dnase_this = self.norm_signal(dnase_this, this_metadata['celltype']) +# assert(np.isnan(seq_this).sum() == 0) +# assert(np.isnan(dnase_this).sum() == 0) +# dnase_this = self.norm_signal(dnase_this, this_metadata['celltype']) return torch.tensor(np.column_stack( [seq_this, dnase_this] From 274900a693db27b432594e26a467294df31c0b24 Mon Sep 17 00:00:00 2001 From: Berton Earnshaw Date: Mon, 19 Apr 2021 17:30:22 -0600 Subject: [PATCH 150/244] Add cosine lr scheduler and params --- examples/configs/datasets.py | 9 +++++---- examples/configs/supported.py | 2 +- examples/scheduler.py | 10 +++++++++- 3 files changed, 15 insertions(+), 6 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index cc9db912..3bcc1326 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -296,11 +296,12 @@ 'algo_log_metric': 'accuracy', 'optimizer': 'Adam', 'optimizer_kwargs': {}, - 'scheduler': None, # TODO cosine with warmup from transformers - 'batch_size': 64, #1400, - 'lr': 1e-3, + 'scheduler': 'cosine_schedule_with_warmup', + 'scheduler_kwargs': {'num_warmup_steps': 5415}, + 'batch_size': 75, + 'lr': 1e-4, 'weight_decay': 1e-5, - 'n_epochs': 60, + 'n_epochs': 90, 'process_outputs_function': 'multiclass_logits_to_pred', }, } diff --git a/examples/configs/supported.py b/examples/configs/supported.py index 8b66b74e..57f89a42 100644 --- a/examples/configs/supported.py +++ b/examples/configs/supported.py @@ -34,4 +34,4 @@ 'gin-virtual', 'logistic_regression', 'code-gpt-py'] algorithms = ['ERM', 'groupDRO', 'deepCORAL', 'IRM'] optimizers = ['SGD', 'Adam', 'AdamW'] -schedulers = ['linear_schedule_with_warmup', 'ReduceLROnPlateau', 'StepLR'] +schedulers = ['linear_schedule_with_warmup', 'cosine_schedule_with_warmup', 'ReduceLROnPlateau', 'StepLR'] diff --git a/examples/scheduler.py b/examples/scheduler.py index 7b966624..d025a7b8 100644 --- a/examples/scheduler.py +++ b/examples/scheduler.py @@ -1,4 +1,5 @@ -from transformers import get_linear_schedule_with_warmup +from transformers import (get_linear_schedule_with_warmup, + get_cosine_schedule_with_warmup) from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR def initialize_scheduler(config, optimizer, n_train_steps): @@ -12,6 +13,13 @@ def initialize_scheduler(config, optimizer, n_train_steps): **config.scheduler_kwargs) step_every_batch = True use_metric = False + elif config.scheduler == 'cosine_schedule_with_warmup': + scheduler = get_cosine_schedule_with_warmup( + optimizer, + num_training_steps=n_train_steps, + **config.scheduler_kwargs) + step_every_batch = True + use_metric = False elif config.scheduler=='ReduceLROnPlateau': assert config.scheduler_metric_name, f'scheduler metric must be specified for {config.scheduler}' scheduler = ReduceLROnPlateau( From 8e675cfa70f9a696ff70e3aa3bab121a94d6b77a Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Mon, 19 Apr 2021 16:42:00 -0700 Subject: [PATCH 151/244] add configs/scheduler key for cosine schedule --- examples/configs/scheduler.py | 5 +++++ wilds/common/data_loaders.py | 17 ++++++++--------- wilds/datasets/rxrx1_dataset.py | 2 +- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/examples/configs/scheduler.py b/examples/configs/scheduler.py index 4ae29b10..e8a80612 100644 --- a/examples/configs/scheduler.py +++ b/examples/configs/scheduler.py @@ -4,6 +4,11 @@ 'num_warmup_steps': 0, }, }, + 'cosine_schedule_with_warmup': { + 'scheduler_kwargs':{ + 'num_warmup_steps': 0, + }, + }, 'ReduceLROnPlateau': { 'scheduler_kwargs':{}, }, diff --git a/wilds/common/data_loaders.py b/wilds/common/data_loaders.py index b806832a..76353678 100644 --- a/wilds/common/data_loaders.py +++ b/wilds/common/data_loaders.py @@ -4,24 +4,24 @@ from torch.utils.data.sampler import WeightedRandomSampler, SubsetRandomSampler from wilds.common.utils import get_counts, split_into_groups -def get_train_loader(loader, dataset, batch_size, +def get_train_loader(loader, dataset, batch_size, uniform_over_groups=None, grouper=None, distinct_groups=True, n_groups_per_batch=None, **loader_kwargs): """ Constructs and returns the data loader for training. Args: - loader (str): Loader type. 'standard' for standard loaders and 'group' for group loaders, - which first samples groups and then samples a fixed number of examples belonging + which first samples groups and then samples a fixed number of examples belonging to each group. - dataset (WILDSDataset or WILDSSubset): Data - batch_size (int): Batch size - - uniform_over_groups (None or bool): Whether to sample the groups uniformly or according to the - natural data distribution. - Setting to None applies the defaults for each type of loaders. - For standard loaders, the default is False. For group loaders, + - uniform_over_groups (None or bool): Whether to sample the groups uniformly or according + to the natural data distribution. + Setting to None applies the defaults for each type of loaders. + For standard loaders, the default is False. For group loaders, the default is True. - grouper (Grouper): Grouper used for group loaders or for uniform_over_groups=True - distinct_groups (bool): Whether to sample distinct_groups within each minibatch for group loaders. - - n_groups_poer_batch (int): Number of groups to sample in each minibatch for group loaders. + - n_groups_per_batch (int): Number of groups to sample in each minibatch for group loaders. - loader_kwargs: kwargs passed into torch DataLoader initialization. Output: - data loader (DataLoader): Data loader. @@ -30,7 +30,6 @@ def get_train_loader(loader, dataset, batch_size, if uniform_over_groups is None or not uniform_over_groups: return DataLoader( dataset, - # shuffle=False, # Shuffle training dataset shuffle=True, # Shuffle training dataset sampler=None, collate_fn=dataset.collate, @@ -82,7 +81,7 @@ def get_eval_loader(loader, dataset, batch_size, grouper=None, **loader_kwargs): """ Constructs and returns the data loader for evaluation. Args: - - loader (str): Loader type. 'standard' for standard loaders. + - loader (str): Loader type. 'standard' for standard loaders. - dataset (WILDSDataset or WILDSSubset): Data - batch_size (int): Batch size - loader_kwargs: kwargs passed into torch DataLoader initialization. diff --git a/wilds/datasets/rxrx1_dataset.py b/wilds/datasets/rxrx1_dataset.py index c7468462..d83b72d4 100644 --- a/wilds/datasets/rxrx1_dataset.py +++ b/wilds/datasets/rxrx1_dataset.py @@ -56,7 +56,7 @@ class RxRx1Dataset(WILDSDataset): 'compressed_size': 7_413_123_845} } - def __init__(self, version=None, root_dir='rxrx1-wilds', download=False, + def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): self._version = version From 8a6214e5c7cf15aa4e082294b7f8a73b33377b41 Mon Sep 17 00:00:00 2001 From: Berton Earnshaw Date: Tue, 20 Apr 2021 07:25:25 -0600 Subject: [PATCH 152/244] Use rxrx1 transforms --- examples/configs/datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 3bcc1326..900ab640 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -286,8 +286,8 @@ 'split_scheme': 'official', 'model': 'resnet50', 'model_kwargs': {'pretrained': True}, - 'train_transform': 'image_base', - 'eval_transform': 'image_base', + 'train_transform': 'rxrx1', + 'eval_transform': 'rxrx1', 'target_resolution': (256, 256), 'loss_function': 'cross_entropy', 'groupby_fields': ['experiment'], From 60c0bfb407ec7aa743bcb5b835ea67b01bb197ce Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Tue, 20 Apr 2021 09:58:20 -0700 Subject: [PATCH 153/244] merge train and eval transforms and add support for rxrx1 transform --- examples/configs/datasets.py | 36 +++++++++---------------- examples/configs/supported.py | 2 +- examples/run_expt.py | 13 ++++----- examples/transforms.py | 50 +++++++++++++++++------------------ 4 files changed, 45 insertions(+), 56 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 3bcc1326..f2bff675 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -2,8 +2,7 @@ 'amazon': { 'split_scheme': 'official', 'model': 'distilbert-base-uncased', - 'train_transform': 'bert', - 'eval_transform': 'bert', + 'transform': 'bert', 'max_token_length': 512, 'loss_function': 'cross_entropy', 'algo_log_metric': 'accuracy', @@ -34,16 +33,14 @@ 'weight_decay': 0.0001, 'n_epochs': 10, 'algo_log_metric': 'multitask_binary_accuracy', - 'train_transform': 'image_base', - 'eval_transform': 'image_base', + 'transform': 'image_base', 'process_outputs_function': 'binary_logits_to_pred', }, 'camelyon17': { 'split_scheme': 'official', 'model': 'densenet121', 'model_kwargs': {'pretrained': False}, - 'train_transform': 'image_base', - 'eval_transform': 'image_base', + 'transform': 'image_base', 'target_resolution': (96, 96), 'loss_function': 'cross_entropy', 'groupby_fields': ['hospital'], @@ -66,8 +63,7 @@ 'split_scheme': 'official', 'model': 'resnet50', 'model_kwargs': {'pretrained': True}, - 'train_transform': 'image_base', - 'eval_transform': 'image_base', + 'transform': 'image_base', 'loss_function': 'cross_entropy', 'groupby_fields': ['male', 'y'], 'val_metric': 'acc_wg', @@ -85,8 +81,7 @@ 'civilcomments': { 'split_scheme': 'official', 'model': 'distilbert-base-uncased', - 'train_transform': 'bert', - 'eval_transform': 'bert', + 'transform': 'bert', 'loss_function': 'cross_entropy', 'groupby_fields': ['black', 'y'], 'val_metric': 'acc_wg', @@ -114,8 +109,7 @@ }, 'model': 'densenet121', 'model_kwargs': {'pretrained': True}, - 'train_transform': 'image_base', - 'eval_transform': 'image_base', + 'transform': 'image_base', 'loss_function': 'cross_entropy', 'groupby_fields': ['year',], 'val_metric': 'acc_worst_region', @@ -137,8 +131,7 @@ 'loss_function': 'cross_entropy', 'val_metric': 'F1-macro_all', 'model_kwargs': {'pretrained': True}, - 'train_transform': 'image_base', - 'eval_transform': 'image_base', + 'transform': 'image_base', 'target_resolution': (448, 448), 'val_metric_decreasing': False, 'algo_log_metric': 'accuracy', @@ -207,8 +200,7 @@ }, 'model': 'resnet18_ms', 'model_kwargs': {'num_channels': 8}, - 'train_transform': 'poverty_train', - 'eval_transform': None, + 'transform': 'poverty', 'loss_function': 'mse', 'groupby_fields': ['country',], 'val_metric': 'r_wg', @@ -229,8 +221,7 @@ 'waterbirds': { 'split_scheme': 'official', 'model': 'resnet50', - 'train_transform': 'image_resize_and_center_crop', - 'eval_transform': 'image_resize_and_center_crop', + 'transform': 'image_resize_and_center_crop', 'resize_scale': 256.0/224.0, 'model_kwargs': {'pretrained': True}, 'loss_function': 'cross_entropy', @@ -250,8 +241,7 @@ 'yelp': { 'split_scheme': 'official', 'model': 'bert-base-uncased', - 'train_transform': 'bert', - 'eval_transform': 'bert', + 'transform': 'bert', 'max_token_length': 512, 'loss_function': 'cross_entropy', 'algo_log_metric': 'accuracy', @@ -265,8 +255,7 @@ 'sqf': { 'split_scheme': 'all_race', 'model': 'logistic_regression', - 'train_transform': None, - 'eval_transform': None, + 'transform': None, 'model_kwargs': {'in_features': 104}, 'loss_function': 'cross_entropy', 'groupby_fields': ['y'], @@ -286,8 +275,7 @@ 'split_scheme': 'official', 'model': 'resnet50', 'model_kwargs': {'pretrained': True}, - 'train_transform': 'image_base', - 'eval_transform': 'image_base', + 'transform': 'rxrx1', 'target_resolution': (256, 256), 'loss_function': 'cross_entropy', 'groupby_fields': ['experiment'], diff --git a/examples/configs/supported.py b/examples/configs/supported.py index 57f89a42..384404b7 100644 --- a/examples/configs/supported.py +++ b/examples/configs/supported.py @@ -28,7 +28,7 @@ } # see initialize_*() functions for correspondence -transforms = ['bert', 'image_base', 'image_resize_and_center_crop', 'poverty_train'] +transforms = ['bert', 'image_base', 'image_resize_and_center_crop', 'poverty', 'rxrx1'] models = ['resnet18_ms', 'resnet50', 'resnet34', 'wideresnet50', 'densenet121', 'bert-base-uncased', 'distilbert-base-uncased', 'gin-virtual', 'logistic_regression', 'code-gpt-py'] diff --git a/examples/run_expt.py b/examples/run_expt.py index 173603ab..990b422e 100644 --- a/examples/run_expt.py +++ b/examples/run_expt.py @@ -53,8 +53,7 @@ def main(): help='keyword arguments for model initialization passed as key1=value1 key2=value2') # Transforms - parser.add_argument('--train_transform', choices=supported.transforms) - parser.add_argument('--eval_transform', choices=supported.transforms) + parser.add_argument('--transform', choices=supported.transforms) parser.add_argument('--target_resolution', nargs='+', type=int, help='The input resolution that images will be resized to before being passed into the model. For example, use --target_resolution 224 224 for a standard ResNet.') parser.add_argument('--resize_scale', type=float) parser.add_argument('--max_token_length', type=int) @@ -148,13 +147,15 @@ def main(): # To implement data augmentation (i.e., have different transforms # at training time vs. test time), modify these two lines: train_transform = initialize_transform( - transform_name=config.train_transform, + transform_name=config.transform, config=config, - dataset=full_dataset) + dataset=full_dataset, + is_training=True) eval_transform = initialize_transform( - transform_name=config.eval_transform, + transform_name=config.transform, config=config, - dataset=full_dataset) + dataset=full_dataset, + is_training=False) train_grouper = CombinatorialGrouper( dataset=full_dataset, diff --git a/examples/transforms.py b/examples/transforms.py index 2b14eedf..7d2dd977 100644 --- a/examples/transforms.py +++ b/examples/transforms.py @@ -5,7 +5,7 @@ from transformers import BertTokenizerFast, DistilBertTokenizerFast import torch -def initialize_transform(transform_name, config, dataset): +def initialize_transform(transform_name, config, dataset, is_training): if transform_name is None: return None elif transform_name=='bert': @@ -14,10 +14,10 @@ def initialize_transform(transform_name, config, dataset): return initialize_image_base_transform(config, dataset) elif transform_name=='image_resize_and_center_crop': return initialize_image_resize_and_center_crop_transform(config, dataset) - elif transform_name=='poverty_train': - return initialize_poverty_train_transform() + elif transform_name=='poverty': + return initialize_poverty_transform(is_training) elif transform_name=='rxrx1': - return initialize_rxrx1_transform(dataset) + return initialize_rxrx1_transform(is_training) else: raise ValueError(f"{transform_name} not recognized") @@ -91,25 +91,26 @@ def initialize_image_resize_and_center_crop_transform(config, dataset): ]) return transform -def initialize_poverty_train_transform(): - transforms_ls = [ - transforms.ToPILImage(), - transforms.RandomHorizontalFlip(), - transforms.RandomVerticalFlip(), - transforms.ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.1), - transforms.ToTensor()] - rgb_transform = transforms.Compose(transforms_ls) - - def transform_rgb(img): - # bgr to rgb and back to bgr - img[:3] = rgb_transform(img[:3][[2,1,0]])[[2,1,0]] - return img - transform = transforms.Lambda(lambda x: transform_rgb(x)) - return transform - - -def initialize_rxrx1_transform(dataset: str): +def initialize_poverty_transform(is_training): + if is_training: + transforms_ls = [ + transforms.ToPILImage(), + transforms.RandomHorizontalFlip(), + transforms.RandomVerticalFlip(), + transforms.ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.1), + transforms.ToTensor()] + rgb_transform = transforms.Compose(transforms_ls) + + def transform_rgb(img): + # bgr to rgb and back to bgr + img[:3] = rgb_transform(img[:3][[2,1,0]])[[2,1,0]] + return img + transform = transforms.Lambda(lambda x: transform_rgb(x)) + return transform + else: + return None +def initialize_rxrx1_transform(is_training): def standardize(x: torch.Tensor) -> torch.Tensor: mean = x.mean(dim=(1, 2)) std = x.std(dim=(1, 2)) @@ -126,17 +127,16 @@ def random_d8(x: torch.Tensor) -> torch.Tensor: return x t_random_d8 = transforms.Lambda(lambda x: random_d8(x)) - if dataset == 'train': + if is_training: transforms_ls = [ t_random_d8, transforms.ToTensor(), t_standardize, ] - elif dataset == 'test': + else: transforms_ls = [ transforms.ToTensor(), t_standardize, ] transform = transforms.Compose(transforms_ls) - return transform From 215b686a3a6521707677fcc85c47500d7fc54b25 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Wed, 21 Apr 2021 16:33:28 -0700 Subject: [PATCH 154/244] smaller resnet support and group evals --- examples/configs/model.py | 12 ++++++++++++ examples/configs/supported.py | 2 +- examples/models/initializer.py | 4 ++-- wilds/datasets/rxrx1_dataset.py | 23 ++++++----------------- 4 files changed, 21 insertions(+), 20 deletions(-) diff --git a/examples/configs/model.py b/examples/configs/model.py index 46714bbe..6635df92 100644 --- a/examples/configs/model.py +++ b/examples/configs/model.py @@ -26,6 +26,18 @@ }, 'target_resolution': (224, 224), }, + 'resnet18': { + 'model_kwargs':{ + 'pretrained':True, + }, + 'target_resolution': (224, 224), + }, + 'resnet34': { + 'model_kwargs':{ + 'pretrained':True, + }, + 'target_resolution': (224, 224), + }, 'resnet50': { 'model_kwargs':{ 'pretrained':True, diff --git a/examples/configs/supported.py b/examples/configs/supported.py index 384404b7..4396bb50 100644 --- a/examples/configs/supported.py +++ b/examples/configs/supported.py @@ -29,7 +29,7 @@ # see initialize_*() functions for correspondence transforms = ['bert', 'image_base', 'image_resize_and_center_crop', 'poverty', 'rxrx1'] -models = ['resnet18_ms', 'resnet50', 'resnet34', 'wideresnet50', +models = ['resnet18_ms', 'resnet50', 'resnet34', 'resnet18', 'wideresnet50', 'densenet121', 'bert-base-uncased', 'distilbert-base-uncased', 'gin-virtual', 'logistic_regression', 'code-gpt-py'] algorithms = ['ERM', 'groupDRO', 'deepCORAL', 'IRM'] diff --git a/examples/models/initializer.py b/examples/models/initializer.py index 4d414763..2f194c68 100644 --- a/examples/models/initializer.py +++ b/examples/models/initializer.py @@ -23,7 +23,7 @@ def initialize_model(config, d_out, is_featurizer=False): If is_featurizer=False: - model: a model that is equivalent to nn.Sequential(featurizer, classifier) """ - if config.model in ('resnet50', 'resnet34', 'wideresnet50', 'densenet121'): + if config.model in ('resnet50', 'resnet34', 'resnet18', 'wideresnet50', 'densenet121'): if is_featurizer: featurizer = initialize_torchvision_model( name=config.model, @@ -105,7 +105,7 @@ def initialize_torchvision_model(name, d_out, **kwargs): elif name == 'densenet121': constructor_name = name last_layer_name = 'classifier' - elif name in ('resnet50', 'resnet34'): + elif name in ('resnet50', 'resnet34', 'resnet18'): constructor_name = name last_layer_name = 'fc' else: diff --git a/wilds/datasets/rxrx1_dataset.py b/wilds/datasets/rxrx1_dataset.py index d83b72d4..24cc39c5 100644 --- a/wilds/datasets/rxrx1_dataset.py +++ b/wilds/datasets/rxrx1_dataset.py @@ -113,7 +113,7 @@ def create_filepath(row): # eval grouper self._eval_grouper = CombinatorialGrouper( dataset=self, - groupby_fields=(['cell_type', 'experiment']) + groupby_fields=(['cell_type']) ) super().__init__(root_dir, download, split_scheme) @@ -132,22 +132,11 @@ def eval(self, y_pred, y_true, metadata, prediction_fn=None): - results (dictionary): Dictionary of evaluation metrics - results_str (str): String summarizing the evaluation metrics """ - metrics = [ - Accuracy(prediction_fn=prediction_fn), - ] - - results = {} - - for i in range(len(metrics)): - results.update({ - **metrics[i].compute(y_pred, y_true), - }) - - results_str = ( - f"Average acc: {results[metrics[0].agg_metric_field]:.3f}\n" - ) - - return results, results_str + metric = Accuracy(prediction_fn=prediction_fn) + return self.standard_group_eval( + metric, + self._eval_grouper, + y_pred, y_true, metadata) def get_input(self, idx): """ From 285561e0d7d99b36a3a2f71d134acda8f0b81cea Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Thu, 22 Apr 2021 12:14:57 -0700 Subject: [PATCH 155/244] remove default args to data processing --- wilds/datasets/encodetfbs_dataset.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index c4fee9cf..0b5c56fb 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -70,8 +70,8 @@ def wrap_anchor( def dnase_normalize( input_bw_celltype, ref_celltypes, - out_fname = 'norm', - data_pfx = '/users/abalsubr/wilds/examples/data/encode-tfbs_v1.0/' + out_fname, + data_pfx ): if not data_pfx.endswith('/'): data_pfx = data_pfx + '/' @@ -249,7 +249,8 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' 'val': 'Validation (OOD)', 'test': 'Test', } - # Add new split scheme specifying custom test and val celltypes in the format val..test., e.g. 'official' is 'val.HepG2.test.liver' + + # Add new split scheme specifying custom test and val celltypes in the format val..test., e.g. 'official' is 'val.HepG2.test.liver' elif '.' in self._split_scheme: all_celltypes = train_celltypes + val_celltype + test_celltype in_val_ct = self._split_scheme.split('.')[1] @@ -334,7 +335,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' self._seq_bp[chrom] = seq_arr[chrom] print(chrom, time.time() - itime) del seq_arr - + # Set up file handles for DNase features, writing normalized DNase tracks along the way. self._dnase_allcelltypes = {} for ct in self._all_celltypes: @@ -377,7 +378,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' self._metric = MultiTaskAveragePrecision() super().__init__(root_dir, download, split_scheme) - + def get_input(self, idx, window_size=12800): """ Returns x for a given idx in metadata_array, which has been filtered to only take windows with the desired stride. @@ -394,7 +395,7 @@ def get_input(self, idx, window_size=12800): seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end] dnase_bw = self._dnase_allcelltypes[this_metadata['celltype']] dnase_this = np.nan_to_num(dnase_bw.values(chrom, interval_start, interval_end, numpy=True)) - + # assert(np.isnan(seq_this).sum() == 0) # assert(np.isnan(dnase_this).sum() == 0) # dnase_this = self.norm_signal(dnase_this, this_metadata['celltype']) From f47ba26a5944455815d5b9df1b66df4825cb4a8f Mon Sep 17 00:00:00 2001 From: aikanor Date: Thu, 22 Apr 2021 12:33:10 -0700 Subject: [PATCH 156/244] fix splits --- wilds/datasets/encodetfbs_dataset.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index c4fee9cf..e4162685 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -149,8 +149,6 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' # For our purposes, we will ignore these ambiguous labels during training and eval. self.y_array[self.y_array == 0.5] = float('nan') - dnase_norm_mode = 'norm' - # Construct splits train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] val_chroms = ['chr2', 'chr9', 'chr11'] @@ -197,7 +195,6 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' 'id_test': 'Test (ID)', } elif self._split_scheme == 'in-dist': - dnase_norm_mode = 'norm_id' splits = { 'train': { 'chroms': train_chroms, @@ -224,7 +221,6 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' } elif 'id-' in self._split_scheme: test_celltype = [ self._split_scheme.split('id-')[1] ] - dnase_norm_mode = 'norm_id' splits = { 'train': { 'chroms': train_chroms, @@ -274,18 +270,24 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' 'chroms': test_chroms, 'celltypes': test_celltype }, + 'id_test': { + 'chroms': test_chroms, + 'celltypes': train_celltypes + } } self._split_dict = { 'train': 0, 'val': 1, 'test': 2, 'id_val': 3, + 'id_test': 4 } self._split_names = { 'train': 'Train', 'val': 'Validation (OOD)', 'test': 'Test', 'id_val': 'Validation (ID)', + 'id_test': 'Test (ID)', } else: raise ValueError(f'Split scheme {self._split_scheme} not recognized') From 4db3f61217f9a6d5b611661729eeac1a4a338bf8 Mon Sep 17 00:00:00 2001 From: Berton Earnshaw Date: Thu, 22 Apr 2021 15:53:47 -0600 Subject: [PATCH 157/244] Add in-dist split scheme --- wilds/datasets/rxrx1_dataset.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/wilds/datasets/rxrx1_dataset.py b/wilds/datasets/rxrx1_dataset.py index d83b72d4..669248f4 100644 --- a/wilds/datasets/rxrx1_dataset.py +++ b/wilds/datasets/rxrx1_dataset.py @@ -61,7 +61,7 @@ def __init__(self, version=None, root_dir='data', download=False, self._version = version self._split_scheme = split_scheme - if self._split_scheme != 'official': + if self._split_scheme not in ['official', 'in-dist']: raise ValueError(f'Split scheme {self._split_scheme} not recognized') # path @@ -71,9 +71,15 @@ def __init__(self, version=None, root_dir='data', download=False, df = pd.read_csv(self._data_dir / 'metadata.csv') # Splits - self._split_dict = {'train': 0, 'val': 1, 'test': 2} - self._split_names = {'train': 'Train', 'val': 'Validation', 'test': 'Test'} - self._split_array = df.dataset.apply(self._split_dict.get).values + if split_scheme == 'official': + self._split_dict = {'train': 0, 'val': 1, 'test': 2} + self._split_names = {'train': 'Train', 'val': 'Validation', 'test': 'Test'} + self._split_array = df.dataset.apply(self._split_dict.get).values + elif split_scheme == 'in-dist': + df = df.query('dataset == "train"') + self._split_dict = {'train': 1, 'test': 2} + self._split_names = {'train': 'Train', 'test': 'Test'} + self._split_array = df.site.values # Filenames def create_filepath(row): From 51ced2e91fcde6d0a3a43ba99e81b180bd06dc63 Mon Sep 17 00:00:00 2001 From: Berton Earnshaw Date: Thu, 22 Apr 2021 17:10:15 -0600 Subject: [PATCH 158/244] Change in-dist split scheme --- wilds/datasets/rxrx1_dataset.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/wilds/datasets/rxrx1_dataset.py b/wilds/datasets/rxrx1_dataset.py index 669248f4..cdac4fab 100644 --- a/wilds/datasets/rxrx1_dataset.py +++ b/wilds/datasets/rxrx1_dataset.py @@ -76,10 +76,12 @@ def __init__(self, version=None, root_dir='data', download=False, self._split_names = {'train': 'Train', 'val': 'Validation', 'test': 'Test'} self._split_array = df.dataset.apply(self._split_dict.get).values elif split_scheme == 'in-dist': - df = df.query('dataset == "train"') - self._split_dict = {'train': 1, 'test': 2} - self._split_names = {'train': 'Train', 'test': 'Test'} - self._split_array = df.site.values + self._split_dict = {'train': 0, 'val': 1, 'test': 2, 'id-test': 3} + self._split_names = {'train': 'Train', 'val': 'Validation', 'test': 'Test', 'id-test': 'In-Distribution Test'} + self._split_array = df.dataset.apply(self._split_dict.get).values + # id-test set + mask = ((df.dataset == "train") & (df.site == 2)).values + self._split_array = np.where(mask, 3, self._split_array) # Filenames def create_filepath(row): From 947c855df2fb3820872c4a57d292a12aada09737 Mon Sep 17 00:00:00 2001 From: Etienne DAVID Date: Mon, 26 Apr 2021 13:45:24 +0200 Subject: [PATCH 159/244] latest baseline --- .gitignore | 3 ++- examples/models/detection/fasterrcnn.py | 10 ++++++---- wilds/common/metrics/all_metrics.py | 2 +- wilds/datasets/gwhd_dataset.py | 4 ++-- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index 02035664..e3b37d5a 100644 --- a/.gitignore +++ b/.gitignore @@ -8,4 +8,5 @@ test_faster paper* .vscode *sh -*ipynb \ No newline at end of file +*ipynb +experiences \ No newline at end of file diff --git a/examples/models/detection/fasterrcnn.py b/examples/models/detection/fasterrcnn.py index d577090b..3a9724a9 100644 --- a/examples/models/detection/fasterrcnn.py +++ b/examples/models/detection/fasterrcnn.py @@ -144,6 +144,7 @@ def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets) )) return torch.stack(objectness_loss), torch.stack(box_loss) + def forward(self, images, # type: ImageList features, # type: Dict[str, Tensor] @@ -183,8 +184,8 @@ def forward(self, # the proposals proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors) proposals = proposals.view(num_images, -1, 4) - boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level) + boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level) losses = {} if self.training: @@ -214,7 +215,6 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets): box_loss (Tensor) """ - class_logits = torch.split(class_logits, 512,dim=0) box_regression = torch.split(box_regression, 512,dim=0) @@ -222,7 +222,6 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets): box_loss = [] for class_logits_, box_regression_, labels_, regression_targets_ in zip(class_logits, box_regression, labels, regression_targets): - classification_loss.append(F.cross_entropy(class_logits_, labels_)) # get indices that correspond to the regression targets for # the corresponding ground truth labels, to be used with @@ -287,16 +286,19 @@ def forward(self, regression_targets = None matched_idxs = None + + box_features = self.box_roi_pool(features, proposals, image_shapes) + box_features = self.box_head(box_features) class_logits, box_regression = self.box_predictor(box_features) - result = torch.jit.annotate(List[Dict[str, torch.Tensor]], []) losses = {} if self.training: assert labels is not None and regression_targets is not None + loss_classifier, loss_box_reg = fastrcnn_loss( class_logits, box_regression, labels, regression_targets) losses = { diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py index 16172033..8989a3d9 100644 --- a/wilds/common/metrics/all_metrics.py +++ b/wilds/common/metrics/all_metrics.py @@ -190,7 +190,7 @@ def _compute_element_wise(self, y_pred ,y_true ): #target_scores = F.softmax(target_logits, dim=1)[..., 0] pred_boxes = target_boxes[target_scores > self.score_threshold] - det_accuracy = torch.mean(torch.stack([ self._accuracy(src_boxes["boxes"],pred_boxes,iou_thr) for iou_thr in np.arange(0.5,0.76,0.05)])) + det_accuracy = torch.mean(torch.stack([ self._accuracy(src_boxes["boxes"],pred_boxes,iou_thr) for iou_thr in np.arange(0.5,0.51,0.05)])) batch_results.append(det_accuracy) return torch.tensor(batch_results) diff --git a/wilds/datasets/gwhd_dataset.py b/wilds/datasets/gwhd_dataset.py index b2b98e7e..c0e7414c 100644 --- a/wilds/datasets/gwhd_dataset.py +++ b/wilds/datasets/gwhd_dataset.py @@ -56,7 +56,7 @@ class GWHDDataset(WILDSDataset): _dataset_name = 'gwhd' _versions_dict = { - '1.0': { + '2.0': { 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x42fa9775eacc453489a428abd59a437d/contents/blob/', 'compressed_size': None}} @@ -101,7 +101,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' labels = [{ "boxes": torch.stack([ - torch.tensor([int(i) for i in box.split(" ")]) + torch.tensor([int(float(i)) for i in box.split(" ")]) for box in boxes.split(";") ]), "labels": torch.tensor([1]*len(list(boxes.split(";")))).long() From cda5af482c0dfabb17a05dc59e5bcfb7e8495df2 Mon Sep 17 00:00:00 2001 From: aikanor Date: Tue, 27 Apr 2021 17:41:00 -0700 Subject: [PATCH 160/244] change preprocessing for REST/JUND --- dataset_preprocessing/encode-tfbs/README.md | 12 ++++++--- .../encode-tfbs/prep_accessibility.py | 4 ++- .../encode-tfbs/prep_metadata_labels.py | 26 +++++++++++++------ examples/models/CNN_genome.py | 20 ++++++++------ 4 files changed, 42 insertions(+), 20 deletions(-) diff --git a/dataset_preprocessing/encode-tfbs/README.md b/dataset_preprocessing/encode-tfbs/README.md index e0e39cc1..1ef362b6 100644 --- a/dataset_preprocessing/encode-tfbs/README.md +++ b/dataset_preprocessing/encode-tfbs/README.md @@ -3,7 +3,7 @@ #### Requirements - pyBigWig -#### Instructions +#### Instructions to create Codalab bundle 1. Download the human genome sequence (hg19 assembly) in FASTA format from http://hgdownload.cse.ucsc.edu/goldenpath/hg19/bigZips/hg19.fa.gz and extract it into `SEQUENCE_PATH`. @@ -11,10 +11,16 @@ 3. Download the DNase accessibility data. This consists of whole-genome DNase files in bigwig format from https://guanfiles.dcmb.med.umich.edu/Leopard/dnase_bigwig/. These are saved with filename `DNASE..fc.signal.bigwig`. -4. Download the labels from the challenge into a label directory `labels/` created for this purpose: +4. Run `prep_accessibility.py`. + +5. Download the labels from the challenge into a label directory `labels/` created for this purpose: - The training chromosome labels for the challenge's training cell types from https://www.synapse.org/#!Synapse:syn7413983 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn7415202 for the TF MAX, , downloaded as MAX.train.labels.tsv.gz ). - The training chromosome labels for the challenge's evaluation cell type (liver) from https://www.synapse.org/#!Synapse:syn8077511 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn8077648 for the TF MAX, downloaded as MAX.train_wc.labels.tsv.gz ). - The validation chromosome labels for the challenge's training cell types from https://www.synapse.org/#!Synapse:syn8441154 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn8442103 for the TF MAX, downloaded as MAX.val.labels.tsv.gz ). - The validation chromosome labels for the challenge's evaluation cell type (liver) from https://www.synapse.org/#!Synapse:syn8442975 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn8443021 for the TF MAX, downloaded as MAX.test.labels.tsv.gz ). -5. Run `prep_metadata_labels.py`. +6. Run `prep_metadata_labels.py`. + + +#### Instructions to run on Codalab bundle +7. \ No newline at end of file diff --git a/dataset_preprocessing/encode-tfbs/prep_accessibility.py b/dataset_preprocessing/encode-tfbs/prep_accessibility.py index 5716998d..65d66341 100644 --- a/dataset_preprocessing/encode-tfbs/prep_accessibility.py +++ b/dataset_preprocessing/encode-tfbs/prep_accessibility.py @@ -169,10 +169,12 @@ def generate_accessibility_archives(input_dir='dnase_bigwigs', output_dir='codal all_celltypes = ch_train_celltypes + ch_val_celltype + ch_test_celltype for ct in all_celltypes: qn_sample_to_array([ct], input_chroms=train_chroms) - + + """ # Create normalized bigwigs for OOD validation split. for ct in all_celltypes: dnase_normalize(ct, ref_celltypes) # Create normalized bigwig for ID validation split. for ct in ch_test_celltype: dnase_normalize(ct, ch_test_celltype, out_fname = 'norm_id') + """ diff --git a/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py b/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py index 3a002c45..00a8b5c3 100644 --- a/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py +++ b/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py @@ -15,10 +15,10 @@ def write_label_bigwigs( celltypes, train_suffix='train.labels.tsv.gz', - val_suffix='val.labels.tsv.gz' + val_suffix='val.labels.tsv.gz', + tf_name='MAX' ): itime = time.time() - tf_name = 'MAX' # Read in metadata dataframe from training+validation data train_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.{}'.format(tf_name, train_suffix)), sep='\t') @@ -73,11 +73,14 @@ def write_label_bigwigs( def write_metadata_products( - celltypes, bed_df_filename='metadata_df.bed', y_arr_filename='metadata_y.npy', - stride=6400, posamb_only=False + celltypes, + bed_df_filename='metadata_df.bed', + y_arr_filename='metadata_y.npy', + stride=6400, + tf_name='MAX', + posamb_only=False ): itime = time.time() - tf_name = 'MAX' celltype_mdta = [] celltype_labels = [] if posamb_only: @@ -131,6 +134,13 @@ def write_metadata_products( if __name__ == '__main__': - _all_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562', 'A549', 'GM12878', 'liver'] - write_label_bigwigs(_all_celltypes) - write_metadata_products(_all_celltypes) + tf_name = 'JUND' + tfs_to_celltypes = { + 'MAX': ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562', 'A549', 'GM12878', 'liver'], + 'REST': ['H1-hESC', 'HeLa-S3', 'HepG2', 'MCF-7', 'Panc1', 'liver'], + 'JUND': ['HCT116', 'HeLa-S3', 'HepG2', 'K562', 'MCF-7', 'liver'] + } + all_celltypes = tfs_to_celltypes[tf_name] + write_label_bigwigs([x for x in all_celltypes if x != 'liver'], tf_name=tf_name) + write_label_bigwigs(['liver'], train_suffix='train_wc.labels.tsv.gz', val_suffix='test.labels.tsv.gz', tf_name=tf_name) + write_metadata_products(all_celltypes, tf_name=tf_name) diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py index a5711a48..6d3f7d0d 100644 --- a/examples/models/CNN_genome.py +++ b/examples/models/CNN_genome.py @@ -5,20 +5,21 @@ import torch.nn.functional as F - -def single_conv(in_channels, out_channels): +def single_conv(in_channels, out_channels, kernel_size=7): + padding_size = int((kernel_size-1)/2) return nn.Sequential( - nn.Conv1d(in_channels, out_channels, 7, padding=3), + nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding_size), nn.BatchNorm1d(out_channels), nn.ReLU(inplace=True) ) -def double_conv(in_channels, out_channels): +def double_conv(in_channels, out_channels, kernel_size=7): + padding_size = int((kernel_size-1)/2) return nn.Sequential( - nn.Conv1d(in_channels, out_channels, 7, padding=3), + nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding_size), nn.BatchNorm1d(out_channels), nn.ReLU(inplace=True), - nn.Conv1d(out_channels, out_channels, 7, padding=3), + nn.Conv1d(out_channels, out_channels, kernel_size, padding=padding_size), nn.BatchNorm1d(out_channels), nn.ReLU(inplace=True) ) @@ -54,6 +55,8 @@ def __init__(self, num_tasks=16, n_channels_in=5): self.conv_last = nn.Conv1d(15, 1, 200, stride=50, padding=0) self.d_out = num_tasks if num_tasks is not None else 253 + + self.fc_last = nn.Linear(253, 128) def forward(self, x): @@ -107,7 +110,8 @@ def forward(self, x): # Default input_size == 12800: x has size N x 1 x 253 at this point. if self.d_out == 253: out = x - else: # middle 128 values - out = x[:, 64:192] + else: + out = self.fc_last(x) + # out = x[:, 64:192] # middle 128 values return out From b34249ba285dedece536e8f142cffe8474d4f78f Mon Sep 17 00:00:00 2001 From: aikanor Date: Wed, 28 Apr 2021 09:07:34 -0700 Subject: [PATCH 161/244] encoding TF in split_scheme --- wilds/datasets/encodetfbs_dataset.py | 62 ++++++++++++++++++---------- 1 file changed, 40 insertions(+), 22 deletions(-) diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 324bc5e4..8bbe1fbf 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -132,31 +132,35 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' self._version = version self._data_dir = self.initialize_data_dir(root_dir, download) self._y_size = 128 - self._transcription_factor = 'MAX' - - # Read in metadata and labels - self._metadata_df = pd.read_csv( - self._data_dir + '/labels/{}/metadata_df.bed'.format(self._transcription_factor), - sep='\t', header=None, - index_col=None, names=['chr', 'start', 'stop', 'celltype'] - ) - self._y_array = torch.tensor(np.load( - self._data_dir + '/labels/{}/metadata_y.npy'.format(self._transcription_factor))) - - # ~10% of the dataset has ambiguous labels - # i.e., we can't tell if there is a binding event or not. - # This typically happens at the flanking regions of peaks. - # For our purposes, we will ignore these ambiguous labels during training and eval. - self.y_array[self.y_array == 0.5] = float('nan') # Construct splits train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] val_chroms = ['chr2', 'chr9', 'chr11'] test_chroms = ['chr1', 'chr8', 'chr21'] - train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'K562', 'A549', 'GM12878'] - val_celltype = ['HepG2'] - test_celltype = ['liver'] + official_train_cts = { + 'MAX': ['H1-hESC', 'HCT116', 'HeLa-S3', 'K562', 'A549', 'GM12878'], + 'REST': ['H1-hESC', 'HeLa-S3', 'MCF-7', 'Panc1'], + 'JUND': ['HCT116', 'HeLa-S3', 'K562', 'MCF-7'] + } + official_val_cts = { + 'MAX': ['HepG2'], 'REST': ['HepG2'], 'JUND': ['HepG2'] + } + official_test_cts = { + 'MAX': ['liver'], 'REST': ['liver'], 'JUND': ['liver'] + } + + # Set the TF in split_scheme by prefacing it with 'tf..' + self._transcription_factor = 'MAX' + if 'tf.' in split_scheme: + tkns = split_scheme.split('.') + self._transcription_factor = tkns[1] + split_scheme = '.'.join(tkns[2:]) self._split_scheme = split_scheme + + train_celltypes = official_train_cts[self._transcription_factor] + val_celltype = official_val_cts[self._transcription_factor] + test_celltype = official_test_cts[self._transcription_factor] + if self._split_scheme == 'official': splits = { 'train': { @@ -246,7 +250,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' 'test': 'Test', } - # Add new split scheme specifying custom test and val celltypes in the format val..test., e.g. 'official' is 'val.HepG2.test.liver' + # Add new split scheme specifying custom test and val celltypes in the format val..test., e.g. 'official' is 'tf.MAX.val.HepG2.test.liver' elif '.' in self._split_scheme: all_celltypes = train_celltypes + val_celltype + test_celltype in_val_ct = self._split_scheme.split('.')[1] @@ -293,6 +297,21 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' else: raise ValueError(f'Split scheme {self._split_scheme} not recognized') + # Read in metadata and labels + self._metadata_df = pd.read_csv( + self._data_dir + '/labels/{}/metadata_df.bed'.format(self._transcription_factor), + sep='\t', header=None, + index_col=None, names=['chr', 'start', 'stop', 'celltype'] + ) + self._y_array = torch.tensor(np.load( + self._data_dir + '/labels/{}/metadata_y.npy'.format(self._transcription_factor))) + + # ~10% of the dataset has ambiguous labels + # i.e., we can't tell if there is a binding event or not. + # This typically happens at the flanking regions of peaks. + # For our purposes, we will ignore these ambiguous labels during training and eval. + self.y_array[self.y_array == 0.5] = float('nan') + self._split_array = -1 * np.ones(self._metadata_df.shape[0]).astype(int) for split, d in splits.items(): chrom_mask = np.isin(self._metadata_df['chr'], d['chroms']) @@ -309,7 +328,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' # Subsample the testing and validation indices, to speed up evaluation. # For the OOD splits (val and test), we subsample by a factor of 3 - # For the id_val split if it exists, we subsample by a factor of 15 + # For the id_val split if it exists, we subsample by a factor of 3*(# of training celltypes) for subsample_seed, (split, subsample_factor) in enumerate( [('val', 3), ('test', 3), ('id_val', 3*len(splits['train']['celltypes'])) ]): if split not in self._split_dict: continue @@ -397,7 +416,6 @@ def get_input(self, idx, window_size=12800): seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end] dnase_bw = self._dnase_allcelltypes[this_metadata['celltype']] dnase_this = np.nan_to_num(dnase_bw.values(chrom, interval_start, interval_end, numpy=True)) - # assert(np.isnan(seq_this).sum() == 0) # assert(np.isnan(dnase_this).sum() == 0) # dnase_this = self.norm_signal(dnase_this, this_metadata['celltype']) From a729278275e582ab24d79f16ae835e4e1d95a1cd Mon Sep 17 00:00:00 2001 From: aikanor Date: Wed, 28 Apr 2021 09:11:03 -0700 Subject: [PATCH 162/244] encoding TF in split_scheme 2/2 --- wilds/datasets/encodetfbs_dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 8bbe1fbf..7395d310 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -124,7 +124,8 @@ class EncodeTFBSDataset(WILDSDataset): _dataset_name = 'encode-tfbs' _versions_dict = { '1.0': { - 'download_url': 'https://worksheets.codalab.org/rest/bundles/0xf0a83ce649c540b39149250dc8e3c66b/contents/blob/', + #'download_url': 'https://worksheets.codalab.org/rest/bundles/0xf0a83ce649c540b39149250dc8e3c66b/contents/blob/', + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x9c282b6e9082440f9dcd61bb605c1eab/contents/blob/', 'compressed_size': None}} def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): From aae22c183b5868668700342a25beeb201a28cd09 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Thu, 29 Apr 2021 12:38:36 -0700 Subject: [PATCH 163/244] change multiprocessing sharing strategy for gwhd --- examples/run_expt.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/examples/run_expt.py b/examples/run_expt.py index 34edbd66..e5eb2fa2 100644 --- a/examples/run_expt.py +++ b/examples/run_expt.py @@ -22,7 +22,6 @@ import torch.multiprocessing def main(): - torch.multiprocessing.set_sharing_strategy('file_system') ''' to see default hyperparams for each dataset/model, look at configs/ ''' parser = argparse.ArgumentParser() @@ -118,10 +117,15 @@ def main(): config = parser.parse_args() config = populate_defaults(config) - # set device + # For the GWHD dataset, we need to change the multiprocessing strategy or there will be + # too many open file descriptors + if config.dataset == 'gwhd': + torch.multiprocessing.set_sharing_strategy('file_system') + + # Set device config.device = torch.device("cuda:" + str(config.device)) if torch.cuda.is_available() else torch.device("cpu") - ## Initialize logs + # Initialize logs if os.path.exists(config.log_dir) and config.resume: resume=True mode='a' From b1c72cf5759e63e36952302ff2ea00ae92d32517 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Thu, 29 Apr 2021 12:47:40 -0700 Subject: [PATCH 164/244] strip out DETR --- examples/configs/datasets.py | 2 +- examples/configs/model.py | 31 -- examples/configs/supported.py | 6 +- examples/losses.py | 43 +- examples/models/detr/README.md | 1 - examples/models/detr/__init__.py | 6 - examples/models/detr/backbone.py | 119 ------ examples/models/detr/detr.py | 348 ---------------- examples/models/detr/matcher.py | 86 ---- examples/models/detr/position_encoding.py | 89 ----- examples/models/detr/transformer.py | 296 -------------- examples/models/detr/util/__init__.py | 1 - examples/models/detr/util/box_ops.py | 88 ---- examples/models/detr/util/misc.py | 467 ---------------------- examples/models/detr/util/plot_utils.py | 107 ----- examples/models/initializer.py | 61 +-- wilds/datasets/gwhd_dataset.py | 17 - 17 files changed, 7 insertions(+), 1761 deletions(-) delete mode 100644 examples/models/detr/README.md delete mode 100644 examples/models/detr/__init__.py delete mode 100644 examples/models/detr/backbone.py delete mode 100644 examples/models/detr/detr.py delete mode 100644 examples/models/detr/matcher.py delete mode 100644 examples/models/detr/position_encoding.py delete mode 100644 examples/models/detr/transformer.py delete mode 100644 examples/models/detr/util/__init__.py delete mode 100644 examples/models/detr/util/box_ops.py delete mode 100644 examples/models/detr/util/misc.py delete mode 100644 examples/models/detr/util/plot_utils.py diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 92bd1ab3..09944562 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -290,7 +290,7 @@ 'model_kwargs': { 'n_classes': 1, 'pretrained': True}, - 'loss_function': 'faster_criterion', + 'loss_function': 'fasterrcnn_criterion', 'groupby_fields': ['location'], 'val_metric': 'detection_accuracy_avg', # TODO 'val_metric_decreasing': False, diff --git a/examples/configs/model.py b/examples/configs/model.py index 6e1a870f..37afbf22 100644 --- a/examples/configs/model.py +++ b/examples/configs/model.py @@ -37,39 +37,8 @@ 'target_resolution': (224, 224), }, 'logistic_regression': {}, - 'detr': { - 'max_grad_norm': 0.1, - 'model_kwargs': { - # Backbone. Always uses sine position embedding. - 'train_backbone': True, - 'backbone': 'resnet50', - 'dilation': False, - # Transformer - 'enc_layers': 6, - 'dec_layers': 6, - 'dim_feedforward': 2048, - 'hidden_dim': 256, - 'dropout': 0.1, - 'nheads': 8, - 'pre_norm': False, - }, - 'loss_kwargs': { - # Matcher - 'set_cost_class': 1, - 'set_cost_bbox': 5, - 'set_cost_giou': 2, - # Loss - 'mask_loss_coef': 1, - 'dice_loss_coef': 1, - 'bbox_loss_coef': 5, - 'giou_loss_coef': 2, - # 'eos_coef': 0.1, - 'eos_coef': 0.5, - } - }, 'fasterrcnn': { 'model_kwargs': { - # Backbone. Always uses sine position embedding. 'pretrained': True, } } diff --git a/examples/configs/supported.py b/examples/configs/supported.py index 32fac62f..b0fd9aab 100644 --- a/examples/configs/supported.py +++ b/examples/configs/supported.py @@ -1,6 +1,5 @@ # metrics from wilds.common.metrics.all_metrics import Accuracy, MultiTaskAccuracy, MSE, multiclass_logits_to_pred, binary_logits_to_pred -from utils import remove_key algo_log_metrics = { 'accuracy': Accuracy(prediction_fn=multiclass_logits_to_pred), @@ -13,7 +12,6 @@ process_outputs_functions = { 'binary_logits_to_pred': binary_logits_to_pred, 'multiclass_logits_to_pred': multiclass_logits_to_pred, - 'remove_detr_aux_outputs': remove_key('aux_outputs'), None: None, } @@ -21,7 +19,7 @@ models = ['resnet18_ms', 'resnet50', 'resnet34', 'wideresnet50', 'densenet121', 'bert-base-uncased', 'distilbert-base-uncased', 'gin-virtual', 'logistic_regression', 'code-gpt-py', - 'detr'] + 'fasterrcnn'] # See algorithms/initializer.py algorithms = ['ERM', 'groupDRO', 'deepCORAL', 'IRM'] @@ -36,4 +34,4 @@ transforms = ['bert', 'image_base', 'image_resize_and_center_crop', 'poverty_train'] # See losses.py -losses = ['cross_entropy', 'lm_cross_entropy', 'MSE', 'multitask_bce', 'detr_set_criterion'] +losses = ['cross_entropy', 'lm_cross_entropy', 'MSE', 'multitask_bce', 'fasterrcnn_criterion'] diff --git a/examples/losses.py b/examples/losses.py index 7fb1ae92..cfd0789e 100644 --- a/examples/losses.py +++ b/examples/losses.py @@ -14,46 +14,9 @@ def initialize_loss(config, d_out): elif config.loss_function == 'multitask_bce': return MultiTaskLoss(loss_fn=nn.BCEWithLogitsLoss(reduction='none')) - elif config.loss_function == 'detr_set_criterion': - return ElementwiseLoss(loss_fn=get_detr_set_criterion(config, d_out)) - elif config.loss_function == 'faster_criterion': - return ElementwiseLoss(loss_fn=get_faster_criterion(config)) + elif config.loss_function == 'fasterrcnn_criterion': + from examples.models.detection.fasterrcnn import FasterRCNNLoss + return ElementwiseLoss(loss_fn=FasterRCNNLoss(config.device)) else: raise ValueError(f'config.loss_function {config.loss_function} not recognized') - - -def get_faster_criterion(config): - from examples.models.detection.fasterrcnn import FasterRCNNLoss - - criterion = FasterRCNNLoss(config.device) - return criterion - - -def get_detr_set_criterion(config, d_out): - from examples.models.detr.matcher import HungarianMatcher - from examples.models.detr.detr import SetCriterion - - matcher = HungarianMatcher( - cost_class=config.loss_kwargs['set_cost_class'], - cost_bbox=config.loss_kwargs['set_cost_bbox'], - cost_giou=config.loss_kwargs['set_cost_giou']) - weight_dict = { - 'loss_ce': 1, - 'loss_bbox': config.loss_kwargs['bbox_loss_coef']} - weight_dict['loss_giou'] = config.loss_kwargs['giou_loss_coef'] - - if config.model_kwargs['aux_loss']: - aux_weight_dict = {} - for i in range(config.model_kwargs['dec_layers'] - 1): - aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()}) - weight_dict.update(aux_weight_dict) - - criterion = SetCriterion( - d_out, - matcher=matcher, - weight_dict=weight_dict, - eos_coef=config.loss_kwargs['eos_coef'], - losses=['labels', 'boxes', 'cardinality']).to(config.device) - - return criterion diff --git a/examples/models/detr/README.md b/examples/models/detr/README.md deleted file mode 100644 index 0be3336f..00000000 --- a/examples/models/detr/README.md +++ /dev/null @@ -1 +0,0 @@ -DETR is licensed under the [Apache License 2.0](https://github.com/facebookresearch/detr/blob/master/LICENSE). Code is adapted from the [DETR GitHub repository](https://github.com/facebookresearch/detr/). diff --git a/examples/models/detr/__init__.py b/examples/models/detr/__init__.py deleted file mode 100644 index 435bab61..00000000 --- a/examples/models/detr/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -# from .detr import build - - -# def build_model(args): - # return build(args) diff --git a/examples/models/detr/backbone.py b/examples/models/detr/backbone.py deleted file mode 100644 index d03e8a5d..00000000 --- a/examples/models/detr/backbone.py +++ /dev/null @@ -1,119 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -""" -Backbone modules. -""" -from collections import OrderedDict - -import torch -import torch.nn.functional as F -import torchvision -from torch import nn -from torchvision.models._utils import IntermediateLayerGetter -from typing import Dict, List - -from .util.misc import NestedTensor, is_main_process - -from .position_encoding import build_position_encoding - - -class FrozenBatchNorm2d(torch.nn.Module): - """ - BatchNorm2d where the batch statistics and the affine parameters are fixed. - - Copy-paste from torchvision.misc.ops with added eps before rqsrt, - without which any other models than torchvision.models.resnet[18,34,50,101] - produce nans. - """ - - def __init__(self, n): - super(FrozenBatchNorm2d, self).__init__() - self.register_buffer("weight", torch.ones(n)) - self.register_buffer("bias", torch.zeros(n)) - self.register_buffer("running_mean", torch.zeros(n)) - self.register_buffer("running_var", torch.ones(n)) - - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs): - num_batches_tracked_key = prefix + 'num_batches_tracked' - if num_batches_tracked_key in state_dict: - del state_dict[num_batches_tracked_key] - - super(FrozenBatchNorm2d, self)._load_from_state_dict( - state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs) - - def forward(self, x): - # move reshapes to the beginning - # to make it fuser-friendly - w = self.weight.reshape(1, -1, 1, 1) - b = self.bias.reshape(1, -1, 1, 1) - rv = self.running_var.reshape(1, -1, 1, 1) - rm = self.running_mean.reshape(1, -1, 1, 1) - eps = 1e-5 - scale = w * (rv + eps).rsqrt() - bias = b - rm * scale - return x * scale + bias - - -class BackboneBase(nn.Module): - - def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): - super().__init__() - for name, parameter in backbone.named_parameters(): - if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: - parameter.requires_grad_(False) - if return_interm_layers: - return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} - else: - return_layers = {'layer4': "0"} - self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) - self.num_channels = num_channels - - def forward(self, tensor_list: NestedTensor): - xs = self.body(tensor_list.tensors) - out: Dict[str, NestedTensor] = {} - for name, x in xs.items(): - m = tensor_list.mask - assert m is not None - mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] - out[name] = NestedTensor(x, mask) - return out - - -class Backbone(BackboneBase): - """ResNet backbone with frozen BatchNorm.""" - def __init__(self, name: str, - train_backbone: bool, - return_interm_layers: bool, - dilation: bool): - backbone = getattr(torchvision.models, name)( - replace_stride_with_dilation=[False, False, dilation], - pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) - num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 - super().__init__(backbone, train_backbone, num_channels, return_interm_layers) - - -class Joiner(nn.Sequential): - def __init__(self, backbone, position_embedding): - super().__init__(backbone, position_embedding) - - def forward(self, tensor_list: NestedTensor): - xs = self[0](tensor_list) - out: List[NestedTensor] = [] - pos = [] - for name, x in xs.items(): - out.append(x) - # position encoding - pos.append(self[1](x).to(x.tensors.dtype)) - - return out, pos - - -def build_backbone(args): - position_embedding = build_position_encoding(args) - train_backbone = args.lr_backbone > 0 - return_interm_layers = args.masks - backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) - model = Joiner(backbone, position_embedding) - model.num_channels = backbone.num_channels - return model diff --git a/examples/models/detr/detr.py b/examples/models/detr/detr.py deleted file mode 100644 index 061c5fcd..00000000 --- a/examples/models/detr/detr.py +++ /dev/null @@ -1,348 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -""" -DETR model and criterion classes. -""" -import torch -import torch.nn.functional as F -from torch import nn - -from .util import box_ops -from .util.misc import (NestedTensor, nested_tensor_from_tensor_list, - accuracy, get_world_size, interpolate, - is_dist_avail_and_initialized) - -from .backbone import build_backbone -from .matcher import build_matcher -from .transformer import build_transformer - - -class DETR(nn.Module): - """ This is the DETR module that performs object detection """ - def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False): - """ Initializes the model. - Parameters: - backbone: torch module of the backbone to be used. See backbone.py - transformer: torch module of the transformer architecture. See transformer.py - num_classes: number of object classes - num_queries: number of object queries, ie detection slot. This is the maximal number of objects - DETR can detect in a single image. For COCO, we recommend 100 queries. - aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. - """ - super().__init__() - self.num_queries = num_queries - self.transformer = transformer - hidden_dim = transformer.d_model - self.class_embed = nn.Linear(hidden_dim, num_classes + 1) - self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) - self.query_embed = nn.Embedding(num_queries, hidden_dim) - self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1) - self.backbone = backbone - self.aux_loss = aux_loss - - def forward(self, samples: NestedTensor): - """ The forward expects a NestedTensor, which consists of: - - samples.tensor: batched images, of shape [batch_size x 3 x H x W] - - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels - - It returns a dict with the following elements: - - "pred_logits": the classification logits (including no-object) for all queries. - Shape= [batch_size x num_queries x (num_classes + 1)] - - "pred_boxes": The normalized boxes coordinates for all queries, represented as - (center_x, center_y, height, width). These values are normalized in [0, 1], - relative to the size of each individual image (disregarding possible padding). - See PostProcess for information on how to retrieve the unnormalized bounding box. - - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of - dictionnaries containing the two above keys for each decoder layer. - """ - if isinstance(samples, (list, torch.Tensor)): - samples = nested_tensor_from_tensor_list(samples) - features, pos = self.backbone(samples) - - src, mask = features[-1].decompose() - assert mask is not None - hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0] - - outputs_class = self.class_embed(hs) - outputs_coord = self.bbox_embed(hs).sigmoid() - out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]} - if self.aux_loss: - out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord) - return out - - @torch.jit.unused - def _set_aux_loss(self, outputs_class, outputs_coord): - # this is a workaround to make torchscript happy, as torchscript - # doesn't support dictionary with non-homogeneous values, such - # as a dict having both a Tensor and a list. - return [{'pred_logits': a, 'pred_boxes': b} - for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] - - -class SetCriterion(nn.Module): - """ This class computes the loss for DETR. - The process happens in two steps: - 1) we compute hungarian assignment between ground truth boxes and the outputs of the model - 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) - """ - def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses): - """ Create the criterion. - Parameters: - num_classes: number of object categories, omitting the special no-object category - matcher: module able to compute a matching between targets and proposals - weight_dict: dict containing as key the names of the losses and as values their relative weight. - eos_coef: relative classification weight applied to the no-object category - losses: list of all the losses to be applied. See get_loss for list of available losses. - """ - super().__init__() - self.num_classes = num_classes - self.matcher = matcher - self.weight_dict = weight_dict - self.eos_coef = eos_coef - self.losses = losses - empty_weight = torch.ones(self.num_classes + 1) - empty_weight[-1] = self.eos_coef - self.register_buffer('empty_weight', empty_weight) - - def loss_labels(self, outputs, targets, indices, num_boxes, log=True): - """Classification loss (NLL) - targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] - """ - assert 'pred_logits' in outputs - src_logits = outputs['pred_logits'] - - idx = self._get_src_permutation_idx(indices) - target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) - target_classes = torch.full(src_logits.shape[:2], self.num_classes, - dtype=torch.int64, device=src_logits.device) - target_classes[idx] = target_classes_o - - # loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight) - loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight, reduction='none') - losses = {'loss_ce': loss_ce.mean(dim=1)} - - if log: - # TODO this should probably be a separate loss, not hacked in this one here - losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0] - return losses - - @torch.no_grad() - def loss_cardinality(self, outputs, targets, indices, num_boxes): - """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes - This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients - """ - pred_logits = outputs['pred_logits'] - device = pred_logits.device - tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device) - # Count the number of predictions that are NOT "no-object" (which is the last class) - card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1) - card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) - losses = {'cardinality_error': card_err} - return losses - - def loss_boxes(self, outputs, targets, indices, num_boxes): - """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss - targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] - The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. - """ - assert 'pred_boxes' in outputs - idx = self._get_src_permutation_idx(indices) - src_boxes = outputs['pred_boxes'][idx] - target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) - - loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none').sum(dim=1) - - loss_giou = 1 - torch.diag(box_ops.generalized_box_iou( - box_ops.box_cxcywh_to_xyxy(src_boxes), - box_ops.box_cxcywh_to_xyxy(target_boxes))) - - tgt_lengths = [len(v["labels"]) for v in targets] - - device = outputs['pred_logits'].device - losses = {} - losses['loss_bbox'] = torch.zeros(len(tgt_lengths), device=device) - losses['loss_giou'] = torch.zeros(len(tgt_lengths), device=device) - - pos = 0 - for i, tgt_length in enumerate(tgt_lengths): - if tgt_length == 0: - losses['loss_bbox'][i] = 0 - losses['loss_giou'][i] = 0 - else: - losses['loss_bbox'][i] = loss_bbox[pos:pos+tgt_length].mean() - losses['loss_giou'][i] = loss_giou[pos:pos+tgt_length].mean() - pos += tgt_length - - # losses['loss_bbox'] = loss_bbox.sum() / num_boxes - # losses['loss_giou'] = loss_giou.sum() / num_boxes - return losses - - def _get_src_permutation_idx(self, indices): - # permute predictions following indices - batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) - src_idx = torch.cat([src for (src, _) in indices]) - return batch_idx, src_idx - - def _get_tgt_permutation_idx(self, indices): - # permute targets following indices - batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) - tgt_idx = torch.cat([tgt for (_, tgt) in indices]) - return batch_idx, tgt_idx - - def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): - loss_map = { - 'labels': self.loss_labels, - 'cardinality': self.loss_cardinality, - 'boxes': self.loss_boxes, - } - assert loss in loss_map, f'do you really want to compute {loss} loss?' - return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) - - def forward(self, outputs, targets): - """ This performs the loss computation. - Parameters: - outputs: dict of tensors, see the output specification of the model for the format - targets: list of dicts, such that len(targets) == batch_size. - The expected keys in each dict depends on the losses applied, see each loss' doc - """ - outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'} - - # Retrieve the matching between the outputs of the last layer and the targets - indices = self.matcher(outputs_without_aux, targets) - - # Compute the average number of target boxes accross all nodes, for normalization purposes - # num_boxes = sum(len(t["labels"]) for t in targets) - # num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) - # if is_dist_avail_and_initialized(): - # torch.distributed.all_reduce(num_boxes) - # num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() - num_boxes = None - - # Compute all the requested losses - total_loss = 0 - losses = {} - for loss in self.losses: - losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes)) - - # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. - if 'aux_outputs' in outputs: - for i, aux_outputs in enumerate(outputs['aux_outputs']): - indices = self.matcher(aux_outputs, targets) - for loss in self.losses: - kwargs = {} - if loss == 'labels': - # Logging is enabled only for the last layer - kwargs = {'log': False} - l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs) - l_dict = {k + f'_{i}': v for k, v in l_dict.items()} - losses.update(l_dict) - - # Sum up weighted losses by element - device = outputs['pred_logits'].device - elementwise_loss = torch.zeros(len(outputs['pred_logits']), device=device) - - # print(f"Losses: class {losses['loss_ce'].detach().cpu().numpy()}, bbox {losses['loss_bbox'].detach().cpu().numpy()}, giou {losses['loss_giou'].detach().cpu().numpy()}") - - for k in self.weight_dict: - elementwise_loss += self.weight_dict[k] * losses[k] - - return elementwise_loss - - - - -class MLP(nn.Module): - """ Very simple multi-layer perceptron (also called FFN)""" - - def __init__(self, input_dim, hidden_dim, output_dim, num_layers): - super().__init__() - self.num_layers = num_layers - h = [hidden_dim] * (num_layers - 1) - self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) - - def forward(self, x): - for i, layer in enumerate(self.layers): - x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) - return x - - -# -# class PostProcess(nn.Module): -# """ This module converts the model's output into the format expected by the coco api""" -# @torch.no_grad() -# def forward(self, outputs, target_sizes): -# """ Perform the computation -# Parameters: -# outputs: raw outputs of the model -# target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch -# For evaluation, this must be the original image size (before any data augmentation) -# For visualization, this should be the image size after data augment, but before padding -# """ -# out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes'] -# -# assert len(out_logits) == len(target_sizes) -# assert target_sizes.shape[1] == 2 -# -# prob = F.softmax(out_logits, -1) -# scores, labels = prob[..., :-1].max(-1) -# -# # convert to [x0, y0, x1, y1] format -# boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) -# # and from relative [0, 1] to absolute [0, height] coordinates -# img_h, img_w = target_sizes.unbind(1) -# scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) -# boxes = boxes * scale_fct[:, None, :] -# -# results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)] -# -# return results -# -# -# -# -# def build(args): -# # the `num_classes` naming here is somewhat misleading. -# # it indeed corresponds to `max_obj_id + 1`, where max_obj_id -# # is the maximum id for a class in your dataset. For example, -# # COCO has a max_obj_id of 90, so we pass `num_classes` to be 91. -# # As another example, for a dataset that has a single class with id 1, -# # you should pass `num_classes` to be 2 (max_obj_id + 1). -# # For more details on this, check the following discussion -# # https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223 -# num_classes = 20 if args.dataset_file != 'coco' else 91 -# if args.dataset_file == "coco_panoptic": -# # for panoptic, we just add a num_classes that is large enough to hold -# # max_obj_id + 1, but the exact value doesn't really matter -# num_classes = 250 -# device = torch.device(args.device) -# -# backbone = build_backbone(args) -# -# transformer = build_transformer(args) -# -# model = DETR( -# backbone, -# transformer, -# num_classes=num_classes, -# num_queries=args.num_queries, -# aux_loss=args.aux_loss, -# ) -# -# matcher = build_matcher(args) -# weight_dict = {'loss_ce': 1, 'loss_bbox': args.bbox_loss_coef} -# weight_dict['loss_giou'] = args.giou_loss_coef -# -# # TODO this is a hack -# if args.aux_loss: -# aux_weight_dict = {} -# for i in range(args.dec_layers - 1): -# aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()}) -# weight_dict.update(aux_weight_dict) -# -# losses = ['labels', 'boxes', 'cardinality'] -# -# criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict, -# eos_coef=args.eos_coef, losses=losses) -# criterion.to(device) -# postprocessors = {'bbox': PostProcess()} -# -# return model, criterion, postprocessors diff --git a/examples/models/detr/matcher.py b/examples/models/detr/matcher.py deleted file mode 100644 index 48f1177a..00000000 --- a/examples/models/detr/matcher.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -""" -Modules to compute the matching cost and solve the corresponding LSAP. -""" -import torch -from scipy.optimize import linear_sum_assignment -from torch import nn - -from .util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou - - -class HungarianMatcher(nn.Module): - """This class computes an assignment between the targets and the predictions of the network - - For efficiency reasons, the targets don't include the no_object. Because of this, in general, - there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, - while the others are un-matched (and thus treated as non-objects). - """ - - def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1): - """Creates the matcher - - Params: - cost_class: This is the relative weight of the classification error in the matching cost - cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost - cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost - """ - super().__init__() - self.cost_class = cost_class - self.cost_bbox = cost_bbox - self.cost_giou = cost_giou - assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0" - - @torch.no_grad() - def forward(self, outputs, targets): - """ Performs the matching - - Params: - outputs: This is a dict that contains at least these entries: - "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits - "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates - - targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: - "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth - objects in the target) containing the class labels - "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates - - Returns: - A list of size batch_size, containing tuples of (index_i, index_j) where: - - index_i is the indices of the selected predictions (in order) - - index_j is the indices of the corresponding selected targets (in order) - For each batch element, it holds: - len(index_i) = len(index_j) = min(num_queries, num_target_boxes) - """ - bs, num_queries = outputs["pred_logits"].shape[:2] - - # We flatten to compute the cost matrices in a batch - out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes] - out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] - - # Also concat the target labels and boxes - tgt_ids = torch.cat([v["labels"] for v in targets]) - tgt_bbox = torch.cat([v["boxes"] for v in targets]) - - # Compute the classification cost. Contrary to the loss, we don't use the NLL, - # but approximate it in 1 - proba[target class]. - # The 1 is a constant that doesn't change the matching, it can be ommitted. - cost_class = -out_prob[:, tgt_ids] - - # Compute the L1 cost between boxes - cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) - - # Compute the giou cost betwen boxes - cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)) - - # Final cost matrix - C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou - C = C.view(bs, num_queries, -1).cpu() - - sizes = [len(v["boxes"]) for v in targets] - indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] - return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] - - -def build_matcher(args): - return HungarianMatcher(cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou) diff --git a/examples/models/detr/position_encoding.py b/examples/models/detr/position_encoding.py deleted file mode 100644 index bc7d9eb4..00000000 --- a/examples/models/detr/position_encoding.py +++ /dev/null @@ -1,89 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -""" -Various positional encodings for the transformer. -""" -import math -import torch -from torch import nn - -from .util.misc import NestedTensor - - -class PositionEmbeddingSine(nn.Module): - """ - This is a more standard version of the position embedding, very similar to the one - used by the Attention is all you need paper, generalized to work on images. - """ - def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): - super().__init__() - self.num_pos_feats = num_pos_feats - self.temperature = temperature - self.normalize = normalize - if scale is not None and normalize is False: - raise ValueError("normalize should be True if scale is passed") - if scale is None: - scale = 2 * math.pi - self.scale = scale - - def forward(self, tensor_list: NestedTensor): - x = tensor_list.tensors - mask = tensor_list.mask - assert mask is not None - not_mask = ~mask - y_embed = not_mask.cumsum(1, dtype=torch.float32) - x_embed = not_mask.cumsum(2, dtype=torch.float32) - if self.normalize: - eps = 1e-6 - y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale - x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale - - dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) - dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) - - pos_x = x_embed[:, :, :, None] / dim_t - pos_y = y_embed[:, :, :, None] / dim_t - pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) - pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) - pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) - return pos - - -class PositionEmbeddingLearned(nn.Module): - """ - Absolute pos embedding, learned. - """ - def __init__(self, num_pos_feats=256): - super().__init__() - self.row_embed = nn.Embedding(50, num_pos_feats) - self.col_embed = nn.Embedding(50, num_pos_feats) - self.reset_parameters() - - def reset_parameters(self): - nn.init.uniform_(self.row_embed.weight) - nn.init.uniform_(self.col_embed.weight) - - def forward(self, tensor_list: NestedTensor): - x = tensor_list.tensors - h, w = x.shape[-2:] - i = torch.arange(w, device=x.device) - j = torch.arange(h, device=x.device) - x_emb = self.col_embed(i) - y_emb = self.row_embed(j) - pos = torch.cat([ - x_emb.unsqueeze(0).repeat(h, 1, 1), - y_emb.unsqueeze(1).repeat(1, w, 1), - ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) - return pos - - -def build_position_encoding(args): - N_steps = args.hidden_dim // 2 - if args.position_embedding in ('v2', 'sine'): - # TODO find a better way of exposing other arguments - position_embedding = PositionEmbeddingSine(N_steps, normalize=True) - elif args.position_embedding in ('v3', 'learned'): - position_embedding = PositionEmbeddingLearned(N_steps) - else: - raise ValueError(f"not supported {args.position_embedding}") - - return position_embedding diff --git a/examples/models/detr/transformer.py b/examples/models/detr/transformer.py deleted file mode 100644 index 714c84df..00000000 --- a/examples/models/detr/transformer.py +++ /dev/null @@ -1,296 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -""" -DETR Transformer class. - -Copy-paste from torch.nn.Transformer with modifications: - * positional encodings are passed in MHattention - * extra LN at the end of encoder is removed - * decoder returns a stack of activations from all decoding layers -""" -import copy -from typing import Optional, List - -import torch -import torch.nn.functional as F -from torch import nn, Tensor - - -class Transformer(nn.Module): - - def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, - num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, - activation="relu", normalize_before=False, - return_intermediate_dec=False): - super().__init__() - - encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, - dropout, activation, normalize_before) - encoder_norm = nn.LayerNorm(d_model) if normalize_before else None - self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) - - decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, - dropout, activation, normalize_before) - decoder_norm = nn.LayerNorm(d_model) - self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, - return_intermediate=return_intermediate_dec) - - self._reset_parameters() - - self.d_model = d_model - self.nhead = nhead - - def _reset_parameters(self): - for p in self.parameters(): - if p.dim() > 1: - nn.init.xavier_uniform_(p) - - def forward(self, src, mask, query_embed, pos_embed): - # flatten NxCxHxW to HWxNxC - bs, c, h, w = src.shape - src = src.flatten(2).permute(2, 0, 1) - pos_embed = pos_embed.flatten(2).permute(2, 0, 1) - query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) - mask = mask.flatten(1) - - tgt = torch.zeros_like(query_embed) - memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) - hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, - pos=pos_embed, query_pos=query_embed) - return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) - - -class TransformerEncoder(nn.Module): - - def __init__(self, encoder_layer, num_layers, norm=None): - super().__init__() - self.layers = _get_clones(encoder_layer, num_layers) - self.num_layers = num_layers - self.norm = norm - - def forward(self, src, - mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - pos: Optional[Tensor] = None): - output = src - - for layer in self.layers: - output = layer(output, src_mask=mask, - src_key_padding_mask=src_key_padding_mask, pos=pos) - - if self.norm is not None: - output = self.norm(output) - - return output - - -class TransformerDecoder(nn.Module): - - def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): - super().__init__() - self.layers = _get_clones(decoder_layer, num_layers) - self.num_layers = num_layers - self.norm = norm - self.return_intermediate = return_intermediate - - def forward(self, tgt, memory, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, - pos: Optional[Tensor] = None, - query_pos: Optional[Tensor] = None): - output = tgt - - intermediate = [] - - for layer in self.layers: - output = layer(output, memory, tgt_mask=tgt_mask, - memory_mask=memory_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, - pos=pos, query_pos=query_pos) - if self.return_intermediate: - intermediate.append(self.norm(output)) - - if self.norm is not None: - output = self.norm(output) - if self.return_intermediate: - intermediate.pop() - intermediate.append(output) - - if self.return_intermediate: - return torch.stack(intermediate) - - return output.unsqueeze(0) - - -class TransformerEncoderLayer(nn.Module): - - def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, - activation="relu", normalize_before=False): - super().__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) - # Implementation of Feedforward model - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - - self.activation = _get_activation_fn(activation) - self.normalize_before = normalize_before - - def with_pos_embed(self, tensor, pos: Optional[Tensor]): - return tensor if pos is None else tensor + pos - - def forward_post(self, - src, - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - pos: Optional[Tensor] = None): - q = k = self.with_pos_embed(src, pos) - src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, - key_padding_mask=src_key_padding_mask)[0] - src = src + self.dropout1(src2) - src = self.norm1(src) - src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) - src = src + self.dropout2(src2) - src = self.norm2(src) - return src - - def forward_pre(self, src, - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - pos: Optional[Tensor] = None): - src2 = self.norm1(src) - q = k = self.with_pos_embed(src2, pos) - src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, - key_padding_mask=src_key_padding_mask)[0] - src = src + self.dropout1(src2) - src2 = self.norm2(src) - src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) - src = src + self.dropout2(src2) - return src - - def forward(self, src, - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - pos: Optional[Tensor] = None): - if self.normalize_before: - return self.forward_pre(src, src_mask, src_key_padding_mask, pos) - return self.forward_post(src, src_mask, src_key_padding_mask, pos) - - -class TransformerDecoderLayer(nn.Module): - - def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, - activation="relu", normalize_before=False): - super().__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) - self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) - # Implementation of Feedforward model - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.norm3 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - self.dropout3 = nn.Dropout(dropout) - - self.activation = _get_activation_fn(activation) - self.normalize_before = normalize_before - - def with_pos_embed(self, tensor, pos: Optional[Tensor]): - return tensor if pos is None else tensor + pos - - def forward_post(self, tgt, memory, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, - pos: Optional[Tensor] = None, - query_pos: Optional[Tensor] = None): - q = k = self.with_pos_embed(tgt, query_pos) - tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, - key_padding_mask=tgt_key_padding_mask)[0] - tgt = tgt + self.dropout1(tgt2) - tgt = self.norm1(tgt) - tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), - key=self.with_pos_embed(memory, pos), - value=memory, attn_mask=memory_mask, - key_padding_mask=memory_key_padding_mask)[0] - tgt = tgt + self.dropout2(tgt2) - tgt = self.norm2(tgt) - tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) - tgt = tgt + self.dropout3(tgt2) - tgt = self.norm3(tgt) - return tgt - - def forward_pre(self, tgt, memory, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, - pos: Optional[Tensor] = None, - query_pos: Optional[Tensor] = None): - tgt2 = self.norm1(tgt) - q = k = self.with_pos_embed(tgt2, query_pos) - tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, - key_padding_mask=tgt_key_padding_mask)[0] - tgt = tgt + self.dropout1(tgt2) - tgt2 = self.norm2(tgt) - tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), - key=self.with_pos_embed(memory, pos), - value=memory, attn_mask=memory_mask, - key_padding_mask=memory_key_padding_mask)[0] - tgt = tgt + self.dropout2(tgt2) - tgt2 = self.norm3(tgt) - tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) - tgt = tgt + self.dropout3(tgt2) - return tgt - - def forward(self, tgt, memory, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, - pos: Optional[Tensor] = None, - query_pos: Optional[Tensor] = None): - if self.normalize_before: - return self.forward_pre(tgt, memory, tgt_mask, memory_mask, - tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) - return self.forward_post(tgt, memory, tgt_mask, memory_mask, - tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) - - -def _get_clones(module, N): - return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) - - -def build_transformer(args): - return Transformer( - d_model=args.hidden_dim, - dropout=args.dropout, - nhead=args.nheads, - dim_feedforward=args.dim_feedforward, - num_encoder_layers=args.enc_layers, - num_decoder_layers=args.dec_layers, - normalize_before=args.pre_norm, - return_intermediate_dec=True, - ) - -def _get_activation_fn(activation): - """Return an activation function given a string""" - if activation == "relu": - return F.relu - if activation == "gelu": - return F.gelu - if activation == "glu": - return F.glu - raise RuntimeError(F"activation should be relu/gelu, not {activation}.") diff --git a/examples/models/detr/util/__init__.py b/examples/models/detr/util/__init__.py deleted file mode 100644 index 168f9979..00000000 --- a/examples/models/detr/util/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved diff --git a/examples/models/detr/util/box_ops.py b/examples/models/detr/util/box_ops.py deleted file mode 100644 index 9c088e5b..00000000 --- a/examples/models/detr/util/box_ops.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -""" -Utilities for bounding box manipulation and GIoU. -""" -import torch -from torchvision.ops.boxes import box_area - - -def box_cxcywh_to_xyxy(x): - x_c, y_c, w, h = x.unbind(-1) - b = [(x_c - 0.5 * w), (y_c - 0.5 * h), - (x_c + 0.5 * w), (y_c + 0.5 * h)] - return torch.stack(b, dim=-1) - - -def box_xyxy_to_cxcywh(x): - x0, y0, x1, y1 = x.unbind(-1) - b = [(x0 + x1) / 2, (y0 + y1) / 2, - (x1 - x0), (y1 - y0)] - return torch.stack(b, dim=-1) - - -# modified from torchvision to also return the union -def box_iou(boxes1, boxes2): - area1 = box_area(boxes1) - area2 = box_area(boxes2) - - lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] - rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] - - wh = (rb - lt).clamp(min=0) # [N,M,2] - inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] - - union = area1[:, None] + area2 - inter - - iou = inter / union - return iou, union - - -def generalized_box_iou(boxes1, boxes2): - """ - Generalized IoU from https://giou.stanford.edu/ - - The boxes should be in [x0, y0, x1, y1] format - - Returns a [N, M] pairwise matrix, where N = len(boxes1) - and M = len(boxes2) - """ - # degenerate boxes gives inf / nan results - # so do an early check - assert (boxes1[:, 2:] >= boxes1[:, :2]).all() - assert (boxes2[:, 2:] >= boxes2[:, :2]).all() - iou, union = box_iou(boxes1, boxes2) - - lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) - rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) - - wh = (rb - lt).clamp(min=0) # [N,M,2] - area = wh[:, :, 0] * wh[:, :, 1] - - return iou - (area - union) / area - - -def masks_to_boxes(masks): - """Compute the bounding boxes around the provided masks - - The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. - - Returns a [N, 4] tensors, with the boxes in xyxy format - """ - if masks.numel() == 0: - return torch.zeros((0, 4), device=masks.device) - - h, w = masks.shape[-2:] - - y = torch.arange(0, h, dtype=torch.float) - x = torch.arange(0, w, dtype=torch.float) - y, x = torch.meshgrid(y, x) - - x_mask = (masks * x.unsqueeze(0)) - x_max = x_mask.flatten(1).max(-1)[0] - x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] - - y_mask = (masks * y.unsqueeze(0)) - y_max = y_mask.flatten(1).max(-1)[0] - y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] - - return torch.stack([x_min, y_min, x_max, y_max], 1) diff --git a/examples/models/detr/util/misc.py b/examples/models/detr/util/misc.py deleted file mode 100644 index 1d4e5eb1..00000000 --- a/examples/models/detr/util/misc.py +++ /dev/null @@ -1,467 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -""" -Misc functions, including distributed helpers. - -Mostly copy-paste from torchvision references. -""" -import os -import subprocess -import time -from collections import defaultdict, deque -import datetime -import pickle -from typing import Optional, List - -import torch -import torch.distributed as dist -from torch import Tensor - -# needed due to empty tensor bug in pytorch and torchvision 0.5 -import torchvision -if float(torchvision.__version__[:3]) < 0.7: - from torchvision.ops import _new_empty_tensor - from torchvision.ops.misc import _output_size - - -class SmoothedValue(object): - """Track a series of values and provide access to smoothed values over a - window or the global series average. - """ - - def __init__(self, window_size=20, fmt=None): - if fmt is None: - fmt = "{median:.4f} ({global_avg:.4f})" - self.deque = deque(maxlen=window_size) - self.total = 0.0 - self.count = 0 - self.fmt = fmt - - def update(self, value, n=1): - self.deque.append(value) - self.count += n - self.total += value * n - - def synchronize_between_processes(self): - """ - Warning: does not synchronize the deque! - """ - if not is_dist_avail_and_initialized(): - return - t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') - dist.barrier() - dist.all_reduce(t) - t = t.tolist() - self.count = int(t[0]) - self.total = t[1] - - @property - def median(self): - d = torch.tensor(list(self.deque)) - return d.median().item() - - @property - def avg(self): - d = torch.tensor(list(self.deque), dtype=torch.float32) - return d.mean().item() - - @property - def global_avg(self): - return self.total / self.count - - @property - def max(self): - return max(self.deque) - - @property - def value(self): - return self.deque[-1] - - def __str__(self): - return self.fmt.format( - median=self.median, - avg=self.avg, - global_avg=self.global_avg, - max=self.max, - value=self.value) - - -def all_gather(data): - """ - Run all_gather on arbitrary picklable data (not necessarily tensors) - Args: - data: any picklable object - Returns: - list[data]: list of data gathered from each rank - """ - world_size = get_world_size() - if world_size == 1: - return [data] - - # serialized to a Tensor - buffer = pickle.dumps(data) - storage = torch.ByteStorage.from_buffer(buffer) - tensor = torch.ByteTensor(storage).to("cuda") - - # obtain Tensor size of each rank - local_size = torch.tensor([tensor.numel()], device="cuda") - size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] - dist.all_gather(size_list, local_size) - size_list = [int(size.item()) for size in size_list] - max_size = max(size_list) - - # receiving Tensor from all ranks - # we pad the tensor because torch all_gather does not support - # gathering tensors of different shapes - tensor_list = [] - for _ in size_list: - tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) - if local_size != max_size: - padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") - tensor = torch.cat((tensor, padding), dim=0) - dist.all_gather(tensor_list, tensor) - - data_list = [] - for size, tensor in zip(size_list, tensor_list): - buffer = tensor.cpu().numpy().tobytes()[:size] - data_list.append(pickle.loads(buffer)) - - return data_list - - -def reduce_dict(input_dict, average=True): - """ - Args: - input_dict (dict): all the values will be reduced - average (bool): whether to do average or sum - Reduce the values in the dictionary from all processes so that all processes - have the averaged results. Returns a dict with the same fields as - input_dict, after reduction. - """ - world_size = get_world_size() - if world_size < 2: - return input_dict - with torch.no_grad(): - names = [] - values = [] - # sort the keys so that they are consistent across processes - for k in sorted(input_dict.keys()): - names.append(k) - values.append(input_dict[k]) - values = torch.stack(values, dim=0) - dist.all_reduce(values) - if average: - values /= world_size - reduced_dict = {k: v for k, v in zip(names, values)} - return reduced_dict - - -class MetricLogger(object): - def __init__(self, delimiter="\t"): - self.meters = defaultdict(SmoothedValue) - self.delimiter = delimiter - - def update(self, **kwargs): - for k, v in kwargs.items(): - if isinstance(v, torch.Tensor): - v = v.item() - assert isinstance(v, (float, int)) - self.meters[k].update(v) - - def __getattr__(self, attr): - if attr in self.meters: - return self.meters[attr] - if attr in self.__dict__: - return self.__dict__[attr] - raise AttributeError("'{}' object has no attribute '{}'".format( - type(self).__name__, attr)) - - def __str__(self): - loss_str = [] - for name, meter in self.meters.items(): - loss_str.append( - "{}: {}".format(name, str(meter)) - ) - return self.delimiter.join(loss_str) - - def synchronize_between_processes(self): - for meter in self.meters.values(): - meter.synchronize_between_processes() - - def add_meter(self, name, meter): - self.meters[name] = meter - - def log_every(self, iterable, print_freq, header=None): - i = 0 - if not header: - header = '' - start_time = time.time() - end = time.time() - iter_time = SmoothedValue(fmt='{avg:.4f}') - data_time = SmoothedValue(fmt='{avg:.4f}') - space_fmt = ':' + str(len(str(len(iterable)))) + 'd' - if torch.cuda.is_available(): - log_msg = self.delimiter.join([ - header, - '[{0' + space_fmt + '}/{1}]', - 'eta: {eta}', - '{meters}', - 'time: {time}', - 'data: {data}', - 'max mem: {memory:.0f}' - ]) - else: - log_msg = self.delimiter.join([ - header, - '[{0' + space_fmt + '}/{1}]', - 'eta: {eta}', - '{meters}', - 'time: {time}', - 'data: {data}' - ]) - MB = 1024.0 * 1024.0 - for obj in iterable: - data_time.update(time.time() - end) - yield obj - iter_time.update(time.time() - end) - if i % print_freq == 0 or i == len(iterable) - 1: - eta_seconds = iter_time.global_avg * (len(iterable) - i) - eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) - if torch.cuda.is_available(): - print(log_msg.format( - i, len(iterable), eta=eta_string, - meters=str(self), - time=str(iter_time), data=str(data_time), - memory=torch.cuda.max_memory_allocated() / MB)) - else: - print(log_msg.format( - i, len(iterable), eta=eta_string, - meters=str(self), - time=str(iter_time), data=str(data_time))) - i += 1 - end = time.time() - total_time = time.time() - start_time - total_time_str = str(datetime.timedelta(seconds=int(total_time))) - print('{} Total time: {} ({:.4f} s / it)'.format( - header, total_time_str, total_time / len(iterable))) - - -def get_sha(): - cwd = os.path.dirname(os.path.abspath(__file__)) - - def _run(command): - return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() - sha = 'N/A' - diff = "clean" - branch = 'N/A' - try: - sha = _run(['git', 'rev-parse', 'HEAD']) - subprocess.check_output(['git', 'diff'], cwd=cwd) - diff = _run(['git', 'diff-index', 'HEAD']) - diff = "has uncommited changes" if diff else "clean" - branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) - except Exception: - pass - message = f"sha: {sha}, status: {diff}, branch: {branch}" - return message - - -def collate_fn(batch): - batch = list(zip(*batch)) - batch[0] = nested_tensor_from_tensor_list(batch[0]) - return tuple(batch) - - -def _max_by_axis(the_list): - # type: (List[List[int]]) -> List[int] - maxes = the_list[0] - for sublist in the_list[1:]: - for index, item in enumerate(sublist): - maxes[index] = max(maxes[index], item) - return maxes - - -class NestedTensor(object): - def __init__(self, tensors, mask: Optional[Tensor]): - self.tensors = tensors - self.mask = mask - - def to(self, device): - # type: (Device) -> NestedTensor # noqa - cast_tensor = self.tensors.to(device) - mask = self.mask - if mask is not None: - assert mask is not None - cast_mask = mask.to(device) - else: - cast_mask = None - return NestedTensor(cast_tensor, cast_mask) - - def decompose(self): - return self.tensors, self.mask - - def __repr__(self): - return str(self.tensors) - - -def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): - # TODO make this more general - if tensor_list[0].ndim == 3: - if torchvision._is_tracing(): - # nested_tensor_from_tensor_list() does not export well to ONNX - # call _onnx_nested_tensor_from_tensor_list() instead - return _onnx_nested_tensor_from_tensor_list(tensor_list) - - # TODO make it support different-sized images - max_size = _max_by_axis([list(img.shape) for img in tensor_list]) - # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) - batch_shape = [len(tensor_list)] + max_size - b, c, h, w = batch_shape - dtype = tensor_list[0].dtype - device = tensor_list[0].device - tensor = torch.zeros(batch_shape, dtype=dtype, device=device) - mask = torch.ones((b, h, w), dtype=torch.bool, device=device) - for img, pad_img, m in zip(tensor_list, tensor, mask): - pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) - m[: img.shape[1], :img.shape[2]] = False - else: - raise ValueError('not supported') - return NestedTensor(tensor, mask) - - -# _onnx_nested_tensor_from_tensor_list() is an implementation of -# nested_tensor_from_tensor_list() that is supported by ONNX tracing. -@torch.jit.unused -def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: - max_size = [] - for i in range(tensor_list[0].dim()): - max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64) - max_size.append(max_size_i) - max_size = tuple(max_size) - - # work around for - # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) - # m[: img.shape[1], :img.shape[2]] = False - # which is not yet supported in onnx - padded_imgs = [] - padded_masks = [] - for img in tensor_list: - padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] - padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) - padded_imgs.append(padded_img) - - m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) - padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) - padded_masks.append(padded_mask.to(torch.bool)) - - tensor = torch.stack(padded_imgs) - mask = torch.stack(padded_masks) - - return NestedTensor(tensor, mask=mask) - - -def setup_for_distributed(is_master): - """ - This function disables printing when not in master process - """ - import builtins as __builtin__ - builtin_print = __builtin__.print - - def print(*args, **kwargs): - force = kwargs.pop('force', False) - if is_master or force: - builtin_print(*args, **kwargs) - - __builtin__.print = print - - -def is_dist_avail_and_initialized(): - if not dist.is_available(): - return False - if not dist.is_initialized(): - return False - return True - - -def get_world_size(): - if not is_dist_avail_and_initialized(): - return 1 - return dist.get_world_size() - - -def get_rank(): - if not is_dist_avail_and_initialized(): - return 0 - return dist.get_rank() - - -def is_main_process(): - return get_rank() == 0 - - -def save_on_master(*args, **kwargs): - if is_main_process(): - torch.save(*args, **kwargs) - - -def init_distributed_mode(args): - if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: - args.rank = int(os.environ["RANK"]) - args.world_size = int(os.environ['WORLD_SIZE']) - args.gpu = int(os.environ['LOCAL_RANK']) - elif 'SLURM_PROCID' in os.environ: - args.rank = int(os.environ['SLURM_PROCID']) - args.gpu = args.rank % torch.cuda.device_count() - else: - print('Not using distributed mode') - args.distributed = False - return - - args.distributed = True - - torch.cuda.set_device(args.gpu) - args.dist_backend = 'nccl' - print('| distributed init (rank {}): {}'.format( - args.rank, args.dist_url), flush=True) - torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, - world_size=args.world_size, rank=args.rank) - torch.distributed.barrier() - setup_for_distributed(args.rank == 0) - - -@torch.no_grad() -def accuracy(output, target, topk=(1,)): - """Computes the precision@k for the specified values of k""" - if target.numel() == 0: - return [torch.zeros([], device=output.device)] - maxk = max(topk) - batch_size = target.size(0) - - _, pred = output.topk(maxk, 1, True, True) - pred = pred.t() - correct = pred.eq(target.view(1, -1).expand_as(pred)) - - res = [] - for k in topk: - correct_k = correct[:k].view(-1).float().sum(0) - res.append(correct_k.mul_(100.0 / batch_size)) - return res - - -def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): - # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor - """ - Equivalent to nn.functional.interpolate, but with support for empty batch sizes. - This will eventually be supported natively by PyTorch, and this - class can go away. - """ - if float(torchvision.__version__[:3]) < 0.7: - if input.numel() > 0: - return torch.nn.functional.interpolate( - input, size, scale_factor, mode, align_corners - ) - - output_shape = _output_size(2, input, size, scale_factor) - output_shape = list(input.shape[:-2]) + list(output_shape) - return _new_empty_tensor(input, output_shape) - else: - return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/examples/models/detr/util/plot_utils.py b/examples/models/detr/util/plot_utils.py deleted file mode 100644 index 0f24bed0..00000000 --- a/examples/models/detr/util/plot_utils.py +++ /dev/null @@ -1,107 +0,0 @@ -""" -Plotting utilities to visualize training logs. -""" -import torch -import pandas as pd -import numpy as np -import seaborn as sns -import matplotlib.pyplot as plt - -from pathlib import Path, PurePath - - -def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'): - ''' - Function to plot specific fields from training log(s). Plots both training and test results. - - :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file - - fields = which results to plot from each log file - plots both training and test for each field. - - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots - - log_name = optional, name of log file if different than default 'log.txt'. - - :: Outputs - matplotlib plots of results in fields, color coded for each log file. - - solid lines are training results, dashed lines are test results. - - ''' - func_name = "plot_utils.py::plot_logs" - - # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path, - # convert single Path to list to avoid 'not iterable' error - - if not isinstance(logs, list): - if isinstance(logs, PurePath): - logs = [logs] - print(f"{func_name} info: logs param expects a list argument, converted to list[Path].") - else: - raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \ - Expect list[Path] or single Path obj, received {type(logs)}") - - # Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir - for i, dir in enumerate(logs): - if not isinstance(dir, PurePath): - raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}") - if not dir.exists(): - raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}") - # verify log_name exists - fn = Path(dir / log_name) - if not fn.exists(): - print(f"-> missing {log_name}. Have you gotten to Epoch 1 in training?") - print(f"--> full path of missing log file: {fn}") - return - - # load log file(s) and plot - dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs] - - fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5)) - - for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))): - for j, field in enumerate(fields): - if field == 'mAP': - coco_eval = pd.DataFrame( - np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1] - ).ewm(com=ewm_col).mean() - axs[j].plot(coco_eval, c=color) - else: - df.interpolate().ewm(com=ewm_col).mean().plot( - y=[f'train_{field}', f'test_{field}'], - ax=axs[j], - color=[color] * 2, - style=['-', '--'] - ) - for ax, field in zip(axs, fields): - ax.legend([Path(p).name for p in logs]) - ax.set_title(field) - - -def plot_precision_recall(files, naming_scheme='iter'): - if naming_scheme == 'exp_id': - # name becomes exp_id - names = [f.parts[-3] for f in files] - elif naming_scheme == 'iter': - names = [f.stem for f in files] - else: - raise ValueError(f'not supported {naming_scheme}') - fig, axs = plt.subplots(ncols=2, figsize=(16, 5)) - for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names): - data = torch.load(f) - # precision is n_iou, n_points, n_cat, n_area, max_det - precision = data['precision'] - recall = data['params'].recThrs - scores = data['scores'] - # take precision for all classes, all areas and 100 detections - precision = precision[0, :, :, 0, -1].mean(1) - scores = scores[0, :, :, 0, -1].mean(1) - prec = precision.mean() - rec = data['recall'][0, :, 0, -1].mean() - print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' + - f'score={scores.mean():0.3f}, ' + - f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}' - ) - axs[0].plot(recall, precision, c=color) - axs[1].plot(recall, scores, c=color) - - axs[0].set_title('Precision / Recall') - axs[0].legend(names) - axs[1].set_title('Scores / Recall') - axs[1].legend(names) - return fig, axs diff --git a/examples/models/initializer.py b/examples/models/initializer.py index c8ee9245..e38aede6 100644 --- a/examples/models/initializer.py +++ b/examples/models/initializer.py @@ -77,17 +77,13 @@ def initialize_model(config, d_out, is_featurizer=False): assert not is_featurizer, "Featurizer not supported for logistic regression" model = nn.Linear(out_features=d_out, **config.model_kwargs) - elif config.model == 'detr': - if is_featurizer: # TODO - raise NotImplementedError('Featurizer not implemented for detection yet') - else: - model = initialize_detr_model(config, d_out) elif config.model == 'fasterrcnn': if is_featurizer: # TODO raise NotImplementedError('Featurizer not implemented for detection yet') else: model = initialize_fasterrcnn_model(config, d_out) model.needs_y = True + else: raise ValueError(f'Model: {config.model} not recognized.') @@ -164,58 +160,3 @@ def initialize_fasterrcnn_model(config, d_out): model = fasterrcnn_resnet50_fpn(pretrained=config.model_kwargs["pretrained"],num_classes=d_out) return model - - -def initialize_detr_model(config, d_out): - - from models.detr.backbone import Backbone, Joiner - from models.detr.position_encoding import PositionEmbeddingSine - from models.detr.transformer import Transformer - from models.detr.detr import DETR - - position_embedding = PositionEmbeddingSine( - config.model_kwargs['hidden_dim'] // 2, - normalize=True) - - backbone = Backbone( - name=config.model_kwargs['backbone'], - train_backbone=config.model_kwargs['train_backbone'], - return_interm_layers=False, # No segmentation - dilation=config.model_kwargs['dilation']) - num_channels = backbone.num_channels - backbone = Joiner(backbone, position_embedding) - backbone.num_channels = num_channels - - transformer = Transformer( - d_model=config.model_kwargs['hidden_dim'], - dropout=config.model_kwargs['dropout'], - nhead=config.model_kwargs['nheads'], - dim_feedforward=config.model_kwargs['dim_feedforward'], - num_encoder_layers=config.model_kwargs['enc_layers'], - num_decoder_layers=config.model_kwargs['dec_layers'], - normalize_before=config.model_kwargs['pre_norm'], - return_intermediate_dec=True, - ) - - model = DETR( - backbone, - transformer, - num_classes=d_out, - num_queries=config.model_kwargs['n_queries'], - aux_loss=config.model_kwargs['aux_loss'], - ) - - if config.model_kwargs['pretrained']: - # Calling torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True, num_classes=d_out) does not work - # due to a ModuleNotFoundError. Perhaps some configuration error there. - # So we have to do it manually. - checkpoint = torch.hub.load_state_dict_from_url( - url='https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth', - map_location='cpu', - check_hash=True) - del checkpoint["model"]["query_embed.weight"] - del checkpoint["model"]["class_embed.weight"] - del checkpoint["model"]["class_embed.bias"] - model.load_state_dict(checkpoint["model"], strict=False) - - return model diff --git a/wilds/datasets/gwhd_dataset.py b/wilds/datasets/gwhd_dataset.py index c0e7414c..1d3622b3 100644 --- a/wilds/datasets/gwhd_dataset.py +++ b/wilds/datasets/gwhd_dataset.py @@ -107,25 +107,8 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' "labels": torch.tensor([1]*len(list(boxes.split(";")))).long() } if type(boxes) != float else { "boxes": torch.empty(0,4), - # "labels": torch.empty(0,1,dtype=torch.long) "labels": torch.empty(0,dtype=torch.long) } for boxes in labels] - # TODO: Figure out empty images - - # The above boxes are (x_min,y_min,x_max,y_max) - # Convert labels into (center_x, center_y, w, h) normalized, which is what DETR expects - # TODO: If it's not standard, we can probably put this in a transform somewhere - """ - for label in labels: - boxes = label['boxes'] - center_x = (boxes[:, 0] + boxes[:, 2]) / 2 / self._original_resolution[0] - center_y = (boxes[:, 1] + boxes[:, 3]) / 2 / self._original_resolution[1] - width = (boxes[:, 2] - boxes[:, 0]) / self._original_resolution[0] - height = (boxes[:, 3] - boxes[:, 1]) / self._original_resolution[1] - label['boxes'] = torch.stack((center_x, center_y, width, height), dim=1) - """ - # num_boxes = [len(example['boxes']) for example in labels] - # print(f'Max num_boxes is {max(num_boxes)}') self._y_array.extend(labels) self._metadata_array.extend(list(df['group'].values)) From ea53135f613236e842771d079ca9af16ad3f9d5d Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Thu, 29 Apr 2021 13:04:49 -0700 Subject: [PATCH 165/244] clean up spacing and comments --- examples/models/detection/fasterrcnn.py | 32 ++++++------------------- examples/run_expt.py | 2 -- examples/train.py | 13 ---------- 3 files changed, 7 insertions(+), 40 deletions(-) diff --git a/examples/models/detection/fasterrcnn.py b/examples/models/detection/fasterrcnn.py index 3a9724a9..d553ea0f 100644 --- a/examples/models/detection/fasterrcnn.py +++ b/examples/models/detection/fasterrcnn.py @@ -85,7 +85,7 @@ def batch_concat_box_prediction_layers(box_cls, box_regression): return new_box_cls, new_box_regression - + class RegionProposalNetworkWILDS(RegionProposalNetwork): def __init__(self, anchor_generator, @@ -113,10 +113,6 @@ def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets) objectness_loss (Tensor) box_loss (Tensor) """ - - - - objectness, pred_bbox_deltas = batch_concat_box_prediction_layers(objectness, pred_bbox_deltas) objectness_loss = [] @@ -130,7 +126,6 @@ def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets) sampled_neg_inds = torch.where(torch.cat(sampled_neg_inds, dim=0))[0] sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0) - box_loss.append(det_utils.smooth_l1_loss( pred_bbox_deltas_[sampled_pos_inds], regression_targets_[sampled_pos_inds], @@ -144,7 +139,7 @@ def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets) )) return torch.stack(objectness_loss), torch.stack(box_loss) - + def forward(self, images, # type: ImageList features, # type: Dict[str, Tensor] @@ -231,7 +226,6 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets): labels_pos = labels_[sampled_pos_inds_subset] N, num_classes = class_logits_.shape - box_regression_ = box_regression_.reshape(N, -1, 4) box_loss_ = det_utils.smooth_l1_loss( @@ -247,7 +241,6 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets): class RoIHeadsWILDS(RoIHeads): def __init__(self, box_roi_pool, box_head, box_predictor, box_fg_iou_thresh, box_bg_iou_thresh,box_batch_size_per_image,box_positive_fraction,bbox_reg_weights,box_score_thresh,box_nms_thresh,box_detections_per_img): - super().__init__(box_roi_pool, box_head, box_predictor, box_fg_iou_thresh, box_bg_iou_thresh, box_batch_size_per_image, box_positive_fraction, @@ -277,7 +270,6 @@ def forward(self, if self.has_keypoint(): assert t["keypoints"].dtype == torch.float32, 'target keypoints must of float type' - # here batch is maintained if self.training: proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets) @@ -286,8 +278,6 @@ def forward(self, regression_targets = None matched_idxs = None - - box_features = self.box_roi_pool(features, proposals, image_shapes) box_features = self.box_head(box_features) @@ -306,7 +296,6 @@ def forward(self, "loss_box_reg": loss_box_reg } - boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes) num_images = len(boxes) for i in range(num_images): @@ -318,9 +307,8 @@ def forward(self, } ) - return result, losses - + def fasterrcnn_resnet50_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True, trainable_backbone_layers=3, **kwargs): @@ -338,17 +326,13 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True, progress=progress) model.load_state_dict(state_dict) - # get number of input features for the classifier in_features = model.roi_heads.box_predictor.cls_score.in_features # replace the pre-trained head with a new one model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes+1) - return model - - class FastWILDS(GeneralizedRCNN): def __init__(self, backbone, num_classes=None, # transform parameters @@ -425,14 +409,14 @@ def __init__(self, backbone, num_classes=None, box_predictor = FastRCNNPredictor( representation_size, num_classes) - + roi_heads = RoIHeadsWILDS( box_roi_pool, box_head, box_predictor, box_fg_iou_thresh, box_bg_iou_thresh, box_batch_size_per_image, box_positive_fraction, bbox_reg_weights, box_score_thresh, box_nms_thresh, box_detections_per_img) - + image_mean = [0., 0., 0.] # small trick because images are already normalized image_std = [1., 1., 1.] @@ -441,7 +425,7 @@ def __init__(self, backbone, num_classes=None, super(FastWILDS, self).__init__(backbone, rpn, roi_heads, transform) # Set your own forward pass def forward(self, images, targets=None): - + if self.training: if targets is None: @@ -464,7 +448,7 @@ def forward(self, images, targets=None): assert len(val) == 2 original_image_sizes.append((val[0], val[1])) - + images, targets = self.transform(images, targets) # Check for degenerate boxes @@ -525,5 +509,3 @@ def forward(self, outputs, targets): return elementwise_loss - - diff --git a/examples/run_expt.py b/examples/run_expt.py index e5eb2fa2..ff0ba47d 100644 --- a/examples/run_expt.py +++ b/examples/run_expt.py @@ -188,7 +188,6 @@ def main(): transform=transform) if split == 'train': - datasets[split]['loader'] = get_train_loader( loader=config.train_loader, dataset=datasets[split]['dataset'], @@ -198,7 +197,6 @@ def main(): distinct_groups=config.distinct_groups, n_groups_per_batch=config.n_groups_per_batch, **config.loader_kwargs) - else: datasets[split]['loader'] = get_eval_loader( loader=config.eval_loader, diff --git a/examples/train.py b/examples/train.py index 963049e3..93cc076e 100644 --- a/examples/train.py +++ b/examples/train.py @@ -2,7 +2,6 @@ from tqdm import tqdm import torch from utils import save_model, save_pred, get_pred_prefix, get_model_prefix, detach_and_clone, collate_list -import torch.autograd.profiler as profiler from configs.supported import process_outputs_functions def run_epoch(algorithm, dataset, general_logger, epoch, config, train): @@ -26,9 +25,6 @@ def run_epoch(algorithm, dataset, general_logger, epoch, config, train): batch_idx = 0 iterator = tqdm(dataset['loader']) if config.progress_bar else dataset['loader'] - # import psutil - # process = psutil.Process(os.getpid()) - for batch in iterator: if train: batch_results = algorithm.update(batch) @@ -48,14 +44,6 @@ def run_epoch(algorithm, dataset, general_logger, epoch, config, train): if train and (batch_idx+1) % config.log_every==0: log_results(algorithm, dataset, general_logger, epoch, batch_idx) - # t = torch.cuda.get_device_properties(0).total_memory - # r = torch.cuda.memory_reserved(0) - # a = torch.cuda.memory_allocated(0) - # f = r-a # free inside reserved - # print(f'Total: {f:10} Reserved: {r:10} Allocated: {a:10} Free: {f:10}') - # - # mem = process.memory_info().rss - # print(f'Mem: {mem / 1024 / 1024:6.1f}M') batch_idx += 1 @@ -63,7 +51,6 @@ def run_epoch(algorithm, dataset, general_logger, epoch, config, train): epoch_y_true = collate_list(epoch_y_true) epoch_metadata = collate_list(epoch_metadata) - results, results_str = dataset['dataset'].eval( epoch_y_pred, epoch_y_true, From 1939ff17a7e1c662ea5bf9f96cb55799ea84c026 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Thu, 29 Apr 2021 13:54:27 -0700 Subject: [PATCH 166/244] clean up prediction function in detection accuracy --- examples/configs/datasets.py | 3 +- examples/utils.py | 10 +++---- wilds/common/data_loaders.py | 19 ++++--------- wilds/common/metrics/all_metrics.py | 43 +++++++++-------------------- 4 files changed, 26 insertions(+), 49 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 09944562..17bb8640 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -292,7 +292,7 @@ 'pretrained': True}, 'loss_function': 'fasterrcnn_criterion', 'groupby_fields': ['location'], - 'val_metric': 'detection_accuracy_avg', # TODO + 'val_metric': 'detection_acc_avg', # TODO 'val_metric_decreasing': False, 'algo_log_metric': None, # TODO 'optimizer': 'Adam', @@ -306,6 +306,7 @@ 'num_workers': 1, 'pin_memory': True, }, + 'process_outputs_function': None, } } diff --git a/examples/utils.py b/examples/utils.py index 67a9358c..73fa1b12 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -182,9 +182,7 @@ def save_pred(y_pred, path_prefix): df = pd.DataFrame(y_pred.numpy()) df.to_csv(path_prefix + '.csv', index=False, header=False) # Dictionary - elif isinstance(y_pred, dict): - torch.save(y_pred, path_prefix + '.pth') - elif isinstance(y_pred, list): + elif isinstance(y_pred, dict) or isinstance(y_pred, list): torch.save(y_pred, path_prefix + '.pth') else: raise TypeError("Invalid type for save_pred") @@ -241,9 +239,11 @@ def collate_list(vec): """ If vec is a list of Tensors, it concatenates them all along the first dimension. - If vec is a list of lists, it joins these lists together, but does not attempt to recursively collate. This allows each element of the list to be, e.g., its own dict. + If vec is a list of lists, it joins these lists together, but does not attempt to + recursively collate. This allows each element of the list to be, e.g., its own dict. - If vec is a list of dicts (with the same keys in each dict), it returns a single dict with the same keys. For each key, it recursively collates all entries in the list. + If vec is a list of dicts (with the same keys in each dict), it returns a single dict + with the same keys. For each key, it recursively collates all entries in the list. """ if not isinstance(vec, list): raise TypeError("collate_list must take in a list") diff --git a/wilds/common/data_loaders.py b/wilds/common/data_loaders.py index 12e7e0c3..0cb23cee 100644 --- a/wilds/common/data_loaders.py +++ b/wilds/common/data_loaders.py @@ -4,20 +4,20 @@ from torch.utils.data.sampler import WeightedRandomSampler, SubsetRandomSampler from wilds.common.utils import get_counts, split_into_groups -def get_train_loader(loader, dataset, batch_size, +def get_train_loader(loader, dataset, batch_size, uniform_over_groups=None, grouper=None, distinct_groups=True, n_groups_per_batch=None, **loader_kwargs): """ Constructs and returns the data loader for training. Args: - loader (str): Loader type. 'standard' for standard loaders and 'group' for group loaders, - which first samples groups and then samples a fixed number of examples belonging + which first samples groups and then samples a fixed number of examples belonging to each group. - dataset (WILDSDataset or WILDSSubset): Data - batch_size (int): Batch size - - uniform_over_groups (None or bool): Whether to sample the groups uniformly or according to the + - uniform_over_groups (None or bool): Whether to sample the groups uniformly or according to the natural data distribution. - Setting to None applies the defaults for each type of loaders. - For standard loaders, the default is False. For group loaders, + Setting to None applies the defaults for each type of loaders. + For standard loaders, the default is False. For group loaders, the default is True. - grouper (Grouper): Grouper used for group loaders or for uniform_over_groups=True - distinct_groups (bool): Whether to sample distinct_groups within each minibatch for group loaders. @@ -63,7 +63,6 @@ def get_train_loader(loader, dataset, batch_size, raise ValueError(f'n_groups_per_batch was set to {n_groups_per_batch} but there are only {grouper.n_groups} groups specified.') group_ids = grouper.metadata_to_group(dataset.metadata_array) - batch_sampler = GroupSampler( group_ids=group_ids, batch_size=batch_size, @@ -71,7 +70,6 @@ def get_train_loader(loader, dataset, batch_size, uniform_over_groups=uniform_over_groups, distinct_groups=distinct_groups) - return DataLoader(dataset, shuffle=None, sampler=None, @@ -84,7 +82,7 @@ def get_eval_loader(loader, dataset, batch_size, grouper=None, **loader_kwargs): """ Constructs and returns the data loader for evaluation. Args: - - loader (str): Loader type. 'standard' for standard loaders. + - loader (str): Loader type. 'standard' for standard loaders. - dataset (WILDSDataset or WILDSSubset): Data - batch_size (int): Batch size - loader_kwargs: kwargs passed into torch DataLoader initialization. @@ -106,7 +104,6 @@ class GroupSampler: then sampling data from those groups. It drops the last batch if it's incomplete. """ - def __init__(self, group_ids, batch_size, n_groups_per_batch, uniform_over_groups, distinct_groups): @@ -131,16 +128,13 @@ def __init__(self, group_ids, batch_size, n_groups_per_batch, self.group_prob = unique_counts.numpy() / unique_counts.numpy().sum() def __iter__(self): - for batch_id in range(self.num_batches): - # Note that we are selecting group indices rather than groups groups_for_batch = np.random.choice( len(self.unique_groups), size=self.n_groups_per_batch, replace=(not self.distinct_groups), p=self.group_prob) - sampled_ids = [ np.random.choice( self.group_indices[group], @@ -151,7 +145,6 @@ def __iter__(self): # Flatten sampled_ids = np.concatenate(sampled_ids) - yield sampled_ids def __len__(self): diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py index 8989a3d9..ec93cc9f 100644 --- a/wilds/common/metrics/all_metrics.py +++ b/wilds/common/metrics/all_metrics.py @@ -147,6 +147,9 @@ def worst(self, metrics): return minimum(metrics) class DummyMetric(Metric): + """ + For testing purposes. This Metric always returns -1. + """ def __init__(self, prediction_fn=None, name=None): self.prediction_fn = prediction_fn if name is None: @@ -170,61 +173,43 @@ class DetectionAccuracy(ElementwiseMetric): Given a specific Intersection over union threshold, determine the accuracy achieved for a one-class detector """ - def __init__(self, prediction_fn=None, iou_threshold=0.5,score_threshold=0.5, name=None): - self.prediction_fn = prediction_fn + def __init__(self, iou_threshold=0.5, score_threshold=0.5, name=None): self.iou_threshold = iou_threshold self.score_threshold = score_threshold if name is None: - name = "detection_accuracy" + name = "detection_acc" super().__init__(name=name) - def _compute_element_wise(self, y_pred ,y_true ): - - + def _compute_element_wise(self, y_pred, y_true): batch_results = [] - for src_boxes, target in zip( y_true, y_pred): + for src_boxes, target in zip(y_true, y_pred): target_boxes = target["boxes"] target_scores = target["scores"] - # Here should be prediction_fn ? - #target_scores = F.softmax(target_logits, dim=1)[..., 0] pred_boxes = target_boxes[target_scores > self.score_threshold] - det_accuracy = torch.mean(torch.stack([ self._accuracy(src_boxes["boxes"],pred_boxes,iou_thr) for iou_thr in np.arange(0.5,0.51,0.05)])) batch_results.append(det_accuracy) return torch.tensor(batch_results) - def _accuracy(self, src_boxes,pred_boxes , iou_threshold): total_gt = len(src_boxes) total_pred = len(pred_boxes) - - if total_gt > 0 and total_pred > 0: - # Define the matcher and distance matrix based on iou matcher = Matcher(iou_threshold,iou_threshold,allow_low_quality_matches=False) - - #src_boxes = box_convert(src_boxes , "cxcywh" ,"xyxy") - #pred_boxes = box_convert(pred_boxes , "cxcywh" ,"xyxy") - - match_quality_matrix = box_iou(src_boxes,pred_boxes) - results = matcher(match_quality_matrix) - true_positive = torch.count_nonzero(results.unique() != -1) matched_elements = results[results > -1] - #in Matcher, a pred element can be matched only twice - false_positive = torch.count_nonzero(results == -1) + ( len(matched_elements) - len(matched_elements.unique())) + false_positive = ( + torch.count_nonzero(results == -1) + + (len(matched_elements) - len(matched_elements.unique())) + ) false_negative = total_gt - true_positive - acc= true_positive / ( true_positive + false_positive + false_negative ) - - - return true_positive / ( true_positive + false_positive + false_negative ) - + acc = true_positive / ( true_positive + false_positive + false_negative ) + return true_positive / ( true_positive + false_positive + false_negative ) elif total_gt == 0: if total_pred > 0: return torch.tensor(0.) @@ -233,7 +218,5 @@ def _accuracy(self, src_boxes,pred_boxes , iou_threshold): elif total_gt > 0 and total_pred == 0: return torch.tensor(0.) - - def worst(self, metrics): return minimum(metrics) From 94a8bdd37534d8db83883f78d33a636d050f86ec Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Thu, 29 Apr 2021 22:42:28 -0700 Subject: [PATCH 167/244] turn y into a list --- wilds/common/utils.py | 15 --------------- wilds/datasets/gwhd_dataset.py | 5 ++++- 2 files changed, 4 insertions(+), 16 deletions(-) diff --git a/wilds/common/utils.py b/wilds/common/utils.py index 47446b06..ebf2f41d 100644 --- a/wilds/common/utils.py +++ b/wilds/common/utils.py @@ -134,18 +134,3 @@ def numel(obj): return len(obj) else: raise TypeError("Invalid type for numel") - -# def get_subset_from_mask(seq, mask): -# """ -# Mask should be a binary vector with the same length as seq. -# """ -# if torch.is_tensor(seq) or isinstance(seq, list): -# if len(mask) != len(seq): -# print(len(mask)) -# print(len(seq)) -# raise ValueError('Mask must have same length as the input.') -# return seq[mask] -# elif isinstance(seq, dict): -# return {k: get_subset_from_mask(v, mask) for k, v in seq.items()} -# else: -# raise TypeError('Input must be a Tensor, list, or dict.') diff --git a/wilds/datasets/gwhd_dataset.py b/wilds/datasets/gwhd_dataset.py index 1d3622b3..221b480a 100644 --- a/wilds/datasets/gwhd_dataset.py +++ b/wilds/datasets/gwhd_dataset.py @@ -11,10 +11,13 @@ def _collate_fn(batch): """ Stack x (batch[0]) and metadata (batch[2]), but not y. + originally, batch = (item1, item2, item3, item4) + after zip, batch = [(item1[0], item2[0], ..), ..] """ batch = list(zip(*batch)) batch[0] = torch.stack(batch[0]) - batch[2] = torch.stack(batch[2]) + batch[1] = list(batch[1]) + batch[2] = torch.stack(batch[2]) return tuple(batch) class GWHDDataset(WILDSDataset): From 1c0e8b787160a224c350f712ec54b182f24c6986 Mon Sep 17 00:00:00 2001 From: aikanor Date: Fri, 30 Apr 2021 16:16:16 -0700 Subject: [PATCH 168/244] change DNase normalization name --- wilds/datasets/encodetfbs_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 7395d310..6bcc03a7 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -358,11 +358,11 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' print(chrom, time.time() - itime) del seq_arr - # Set up file handles for DNase features, writing normalized DNase tracks along the way. + # Set up file handles for DNase features, writing normalized DNase tracks along the way if they aren't already written. self._dnase_allcelltypes = {} for ct in self._all_celltypes: orig_dnase_bw_path = os.path.join(self._data_dir, 'DNASE.{}.fc.signal.bigwig'.format(ct)) - dnase_bw_path = os.path.join(self._data_dir, 'DNase.{}.{}.bigwig'.format(ct, self._split_scheme)) + dnase_bw_path = os.path.join(self._data_dir, 'DNase.{}.{}.{}.bigwig'.format(self._transcription_factor, ct, self._split_scheme)) if not os.path.exists(dnase_bw_path): ref_celltypes = splits['train']['celltypes'] dnase_normalize(ct, ref_celltypes, out_fname=self._split_scheme, data_pfx=self._data_dir) From 15acda0e4068671adaf51bfab607a153292dc1be Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Mon, 3 May 2021 14:44:42 -0700 Subject: [PATCH 169/244] fix mse case sensitivity --- examples/losses.py | 2 +- examples/run_expt.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/losses.py b/examples/losses.py index cfd0789e..2e07318e 100644 --- a/examples/losses.py +++ b/examples/losses.py @@ -8,7 +8,7 @@ def initialize_loss(config, d_out): elif config.loss_function == 'lm_cross_entropy': return MultiTaskLoss(loss_fn=nn.CrossEntropyLoss(reduction='none')) - elif config.loss_function == 'MSE': + elif config.loss_function == 'mse': return MSE(name='loss') elif config.loss_function == 'multitask_bce': diff --git a/examples/run_expt.py b/examples/run_expt.py index ff0ba47d..19f6d570 100644 --- a/examples/run_expt.py +++ b/examples/run_expt.py @@ -8,6 +8,7 @@ import sys from collections import defaultdict +sys.path.insert(1, os.path.join(sys.path[0], '..')) import wilds from wilds.common.data_loaders import get_train_loader, get_eval_loader from wilds.common.grouper import CombinatorialGrouper From 9e7a655d13af22c9baa18f5237a26af7cca3d232 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Mon, 3 May 2021 21:28:00 -0700 Subject: [PATCH 170/244] fix needs_y for model tuples and include for MSE --- examples/losses.py | 1 + examples/models/initializer.py | 7 ++++++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/losses.py b/examples/losses.py index 2e07318e..e969a10e 100644 --- a/examples/losses.py +++ b/examples/losses.py @@ -1,5 +1,6 @@ import torch.nn as nn from wilds.common.metrics.loss import ElementwiseLoss, Loss, MultiTaskLoss +from wilds.common.metrics.all_metrics import MSE def initialize_loss(config, d_out): if config.loss_function == 'cross_entropy': diff --git a/examples/models/initializer.py b/examples/models/initializer.py index e38aede6..a9580c93 100644 --- a/examples/models/initializer.py +++ b/examples/models/initializer.py @@ -93,7 +93,12 @@ def initialize_model(config, d_out, is_featurizer=False): # If True, Algorithm.process_batch() will call model(x, y) during training, # and model(x, None) during eval. if not hasattr(model, 'needs_y'): - model.needs_y = False + # Sometimes model is a tuple of (featurizer, classifier) + if isinstance(model, tuple): + for submodel in model: + submodel.needs_y = False + else: + model.needs_y = False return model From 2a0832ccb5e6bb99157e770273b45ed5a0d3a36c Mon Sep 17 00:00:00 2001 From: aikanor Date: Fri, 7 May 2021 15:43:20 -0700 Subject: [PATCH 171/244] Fix DNase file naming --- wilds/datasets/encodetfbs_dataset.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 6bcc03a7..3cac8b48 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -82,7 +82,6 @@ def dnase_normalize( ref += (1.0/len(ref_celltypes))*np.load(data_pfx + "qn.{}.npy".format(ct)) chromsizes_list = [(k, v) for k, v in chrom_sizes.items()] - out_fname = data_pfx + 'DNase.{}.{}.bigwig'.format(input_bw_celltype, out_fname) bw_output = pyBigWig.open(out_fname, 'w') bw_output.addHeader(chromsizes_list) # bw_output.addHeader(list(zip(chr_all , num_bp)), maxZooms=0) # zip two turples @@ -251,7 +250,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' 'test': 'Test', } - # Add new split scheme specifying custom test and val celltypes in the format val..test., e.g. 'official' is 'tf.MAX.val.HepG2.test.liver' + # Add new split scheme specifying custom test and val celltypes in the format val..test., e.g. self._split_scheme == 'official' is equivalent to self._split_scheme == 'val.HepG2.test.liver' elif '.' in self._split_scheme: all_celltypes = train_celltypes + val_celltype + test_celltype in_val_ct = self._split_scheme.split('.')[1] @@ -365,7 +364,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' dnase_bw_path = os.path.join(self._data_dir, 'DNase.{}.{}.{}.bigwig'.format(self._transcription_factor, ct, self._split_scheme)) if not os.path.exists(dnase_bw_path): ref_celltypes = splits['train']['celltypes'] - dnase_normalize(ct, ref_celltypes, out_fname=self._split_scheme, data_pfx=self._data_dir) + dnase_normalize(ct, ref_celltypes, out_fname=dnase_bw_path, data_pfx=self._data_dir) self._dnase_allcelltypes[ct] = pyBigWig.open(dnase_bw_path) # Load subsampled DNase arrays for normalization purposes From bff90552c4c146e1dd83a7bf28510134687ca816 Mon Sep 17 00:00:00 2001 From: Etienne David Date: Sat, 8 May 2021 12:30:43 +0200 Subject: [PATCH 172/244] Revert "GWHD and FasterRCNN support" --- .gitignore | 8 - .vscode/settings.json | 3 - examples/algorithms/algorithm.py | 15 +- examples/algorithms/group_algorithm.py | 4 +- examples/algorithms/initializer.py | 29 +- examples/algorithms/single_model_algorithm.py | 17 +- examples/configs/datasets.py | 26 - examples/configs/model.py | 11 +- examples/configs/supported.py | 30 +- examples/losses.py | 23 - examples/models/detection/fasterrcnn.py | 511 ------------------ examples/models/initializer.py | 56 +- examples/run_expt.py | 22 +- examples/train.py | 45 +- examples/transforms.py | 18 +- examples/utils.py | 69 +-- setup.py | 1 - wilds/__init__.py | 1 - wilds/common/data_loaders.py | 17 +- wilds/common/metrics/all_metrics.py | 82 +-- wilds/common/metrics/metric.py | 5 +- wilds/common/utils.py | 12 +- wilds/datasets/gwhd_dataset.py | 146 ----- wilds/datasets/wilds_dataset.py | 20 +- wilds/get_dataset.py | 6 +- 25 files changed, 108 insertions(+), 1069 deletions(-) delete mode 100644 .vscode/settings.json delete mode 100644 examples/losses.py delete mode 100644 examples/models/detection/fasterrcnn.py delete mode 100644 wilds/datasets/gwhd_dataset.py diff --git a/.gitignore b/.gitignore index acf51ee9..1d3b5479 100644 --- a/.gitignore +++ b/.gitignore @@ -3,11 +3,3 @@ build dist venv wilds.egg-info -data -logs -test_faster -paper* -.vscode -*sh -*ipynb -experiences \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index d0c7592c..00000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "python.pythonPath": "/home/gdaubige/anaconda3/envs/wilds/bin/python" -} \ No newline at end of file diff --git a/examples/algorithms/algorithm.py b/examples/algorithms/algorithm.py index 5c734766..c93d960a 100644 --- a/examples/algorithms/algorithm.py +++ b/examples/algorithms/algorithm.py @@ -1,6 +1,5 @@ import torch import torch.nn as nn -from utils import move_to, detach_and_clone class Algorithm(nn.Module): def __init__(self, device): @@ -94,13 +93,19 @@ def sanitize_dict(self, in_dict, to_out_device=True): Helper function that sanitizes dictionaries by: - moving to the specified output device - removing any gradient information - - detaching and cloning the tensors + - turning any Tensor of size 1 to a simple number Args: - in_dict (dictionary) Output: - out_dict (dictionary): sanitized version of in_dict """ - out_dict = detach_and_clone(in_dict) - if to_out_device: - out_dict = move_to(out_dict, self.out_device) + out_dict = {} + for k, v in in_dict.items(): + if isinstance(v, torch.Tensor): + v_out = v.detach().clone() + if to_out_device: + v_out = v_out.to(self.out_device) + else: + v_out = v + out_dict[k] = v_out return out_dict diff --git a/examples/algorithms/group_algorithm.py b/examples/algorithms/group_algorithm.py index eb0b95c2..54cac1a8 100644 --- a/examples/algorithms/group_algorithm.py +++ b/examples/algorithms/group_algorithm.py @@ -3,7 +3,7 @@ from algorithms.algorithm import Algorithm from utils import update_average from scheduler import step_scheduler -from wilds.common.utils import get_counts, numel +from wilds.common.utils import get_counts class GroupAlgorithm(Algorithm): """ @@ -57,7 +57,7 @@ def update_log(self, results): results['y_pred'], results['y_true'], return_dict=False).item() - count = numel(results['y_true']) + count = results['y_true'].numel() # transfer other statistics in the results dictionary for field in self.logged_fields: diff --git a/examples/algorithms/initializer.py b/examples/algorithms/initializer.py index 180e9ff5..00748cfc 100644 --- a/examples/algorithms/initializer.py +++ b/examples/algorithms/initializer.py @@ -3,8 +3,7 @@ from algorithms.groupDRO import GroupDRO from algorithms.deepCORAL import DeepCORAL from algorithms.IRM import IRM -from configs.supported import algo_log_metrics -from losses import initialize_loss +from configs.supported import algo_log_metrics, losses def initialize_algorithm(config, datasets, train_grouper): train_dataset = datasets['train']['dataset'] @@ -12,27 +11,23 @@ def initialize_algorithm(config, datasets, train_grouper): # Configure the final layer of the networks used # The code below are defaults. Edit this if you need special config for your model. - if train_dataset.is_classification: - if train_dataset.y_size == 1: - # For single-task classification, we have one output per class - d_out = train_dataset.n_classes - elif train_dataset.y_size is None: - d_out = train_dataset.n_classes - elif (train_dataset.y_size > 1) and (train_dataset.n_classes == 2): - # For multi-task binary classification (each output is the logit for each binary class) - d_out = train_dataset.y_size - else: - raise RuntimeError('d_out not defined.') - elif train_dataset.is_detection: - # For detection, d_out is the number of classes + if (train_dataset.is_classification) and (train_dataset.y_size == 1): + # For single-task classification, we have one output per class d_out = train_dataset.n_classes - else: + elif (train_dataset.is_classification) and (train_dataset.y_size is None): + d_out = train_dataset.n_classes + elif (train_dataset.is_classification) and (train_dataset.y_size > 1) and (train_dataset.n_classes == 2): + # For multi-task binary classification (each output is the logit for each binary class) + d_out = train_dataset.y_size + elif (not train_dataset.is_classification): # For regression, we have one output per target dimension d_out = train_dataset.y_size + else: + raise RuntimeError('d_out not defined.') # Other config n_train_steps = len(train_loader) * config.n_epochs - loss = initialize_loss(config, d_out) + loss = losses[config.loss_function] metric = algo_log_metrics[config.algo_log_metric] if config.algorithm=='ERM': diff --git a/examples/algorithms/single_model_algorithm.py b/examples/algorithms/single_model_algorithm.py index f01c21bb..e368b88f 100644 --- a/examples/algorithms/single_model_algorithm.py +++ b/examples/algorithms/single_model_algorithm.py @@ -3,7 +3,6 @@ from scheduler import initialize_scheduler from optimizer import initialize_optimizer from torch.nn.utils import clip_grad_norm_ -from utils import move_to class SingleModelAlgorithm(GroupAlgorithm): """ @@ -48,19 +47,11 @@ def process_batch(self, batch): - y_true """ x, y_true, metadata = batch - x = move_to(x, self.device) - y_true = move_to(y_true, self.device) - g = move_to(self.grouper.metadata_to_group(metadata), self.device) + x = x.to(self.device) + y_true = y_true.to(self.device) + g = self.grouper.metadata_to_group(metadata).to(self.device) + outputs = self.model(x) - - if self.model.needs_y: - if self.training: - outputs = self.model(x, y_true) - else: - outputs = self.model(x, None) - else: - outputs = self.model(x) - results = { 'g': g, 'y_true': y_true, diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 17bb8640..cd2d1d6f 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -282,32 +282,6 @@ 'n_epochs': 4, 'process_outputs_function': None, }, - 'gwhd': { - 'split_scheme': 'official', - 'model': 'fasterrcnn', - 'train_transform': 'image_base', - 'eval_transform': 'image_base', - 'model_kwargs': { - 'n_classes': 1, - 'pretrained': True}, - 'loss_function': 'fasterrcnn_criterion', - 'groupby_fields': ['location'], - 'val_metric': 'detection_acc_avg', # TODO - 'val_metric_decreasing': False, - 'algo_log_metric': None, # TODO - 'optimizer': 'Adam', - 'optimizer_kwargs': {}, - 'scheduler': None, - 'batch_size': 4, - 'lr': 1e-5, - 'weight_decay': 1e-4, - 'n_epochs': 10, - 'loader_kwargs': { - 'num_workers': 1, - 'pin_memory': True, - }, - 'process_outputs_function': None, - } } ########################################## diff --git a/examples/configs/model.py b/examples/configs/model.py index 37afbf22..46714bbe 100644 --- a/examples/configs/model.py +++ b/examples/configs/model.py @@ -15,19 +15,19 @@ 'scheduler': 'linear_schedule_with_warmup', }, 'densenet121': { - 'model_kwargs': { + 'model_kwargs':{ 'pretrained':True, }, 'target_resolution': (224, 224), }, 'wideresnet50': { - 'model_kwargs': { + 'model_kwargs':{ 'pretrained':True, }, 'target_resolution': (224, 224), }, 'resnet50': { - 'model_kwargs': { + 'model_kwargs':{ 'pretrained':True, }, 'target_resolution': (224, 224), @@ -37,9 +37,4 @@ 'target_resolution': (224, 224), }, 'logistic_regression': {}, - 'fasterrcnn': { - 'model_kwargs': { - 'pretrained': True, - } - } } diff --git a/examples/configs/supported.py b/examples/configs/supported.py index b0fd9aab..8b66b74e 100644 --- a/examples/configs/supported.py +++ b/examples/configs/supported.py @@ -1,6 +1,18 @@ +import torch.nn as nn +import torch +import sys, os + # metrics +from wilds.common.metrics.loss import ElementwiseLoss, Loss, MultiTaskLoss from wilds.common.metrics.all_metrics import Accuracy, MultiTaskAccuracy, MSE, multiclass_logits_to_pred, binary_logits_to_pred +losses = { + 'cross_entropy': ElementwiseLoss(loss_fn=nn.CrossEntropyLoss(reduction='none')), + 'lm_cross_entropy': MultiTaskLoss(loss_fn=nn.CrossEntropyLoss(reduction='none')), + 'mse': MSE(name='loss'), + 'multitask_bce': MultiTaskLoss(loss_fn=nn.BCEWithLogitsLoss(reduction='none')), +} + algo_log_metrics = { 'accuracy': Accuracy(prediction_fn=multiclass_logits_to_pred), 'mse': MSE(), @@ -15,23 +27,11 @@ None: None, } -# See models/initializer.py +# see initialize_*() functions for correspondence +transforms = ['bert', 'image_base', 'image_resize_and_center_crop', 'poverty_train'] models = ['resnet18_ms', 'resnet50', 'resnet34', 'wideresnet50', 'densenet121', 'bert-base-uncased', 'distilbert-base-uncased', - 'gin-virtual', 'logistic_regression', 'code-gpt-py', - 'fasterrcnn'] - -# See algorithms/initializer.py + 'gin-virtual', 'logistic_regression', 'code-gpt-py'] algorithms = ['ERM', 'groupDRO', 'deepCORAL', 'IRM'] - -# See optimizer.py optimizers = ['SGD', 'Adam', 'AdamW'] - -# See scheduler.py schedulers = ['linear_schedule_with_warmup', 'ReduceLROnPlateau', 'StepLR'] - -# See transforms.py -transforms = ['bert', 'image_base', 'image_resize_and_center_crop', 'poverty_train'] - -# See losses.py -losses = ['cross_entropy', 'lm_cross_entropy', 'MSE', 'multitask_bce', 'fasterrcnn_criterion'] diff --git a/examples/losses.py b/examples/losses.py deleted file mode 100644 index e969a10e..00000000 --- a/examples/losses.py +++ /dev/null @@ -1,23 +0,0 @@ -import torch.nn as nn -from wilds.common.metrics.loss import ElementwiseLoss, Loss, MultiTaskLoss -from wilds.common.metrics.all_metrics import MSE - -def initialize_loss(config, d_out): - if config.loss_function == 'cross_entropy': - return ElementwiseLoss(loss_fn=nn.CrossEntropyLoss(reduction='none')) - - elif config.loss_function == 'lm_cross_entropy': - return MultiTaskLoss(loss_fn=nn.CrossEntropyLoss(reduction='none')) - - elif config.loss_function == 'mse': - return MSE(name='loss') - - elif config.loss_function == 'multitask_bce': - return MultiTaskLoss(loss_fn=nn.BCEWithLogitsLoss(reduction='none')) - - elif config.loss_function == 'fasterrcnn_criterion': - from examples.models.detection.fasterrcnn import FasterRCNNLoss - return ElementwiseLoss(loss_fn=FasterRCNNLoss(config.device)) - - else: - raise ValueError(f'config.loss_function {config.loss_function} not recognized') diff --git a/examples/models/detection/fasterrcnn.py b/examples/models/detection/fasterrcnn.py deleted file mode 100644 index d553ea0f..00000000 --- a/examples/models/detection/fasterrcnn.py +++ /dev/null @@ -1,511 +0,0 @@ -import torch -import torch.nn as nn -import torchvision -from collections import OrderedDict -import torch -from torch import nn, Tensor -import warnings -from typing import Tuple, List, Dict, Optional, Union - -from torch import nn - - -import torchvision -from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, FasterRCNN -from torchvision.models.detection.backbone_utils import resnet_fpn_backbone -from torchvision.models.utils import load_state_dict_from_url - - -from torchvision.ops import misc as misc_nn_ops -from torchvision.ops import MultiScaleRoIAlign - - -from torchvision.models.detection.anchor_utils import AnchorGenerator -from torchvision.models.detection.generalized_rcnn import GeneralizedRCNN -from torchvision.models.detection.faster_rcnn import TwoMLPHead - -from torchvision.models.detection.rpn import RPNHead, RegionProposalNetwork, concat_box_prediction_layers,permute_and_flatten -from torchvision.models.detection.roi_heads import RoIHeads - -from torchvision.models.detection import _utils as det_utils -from torch.nn import functional as F -from torchvision.models.detection.transform import GeneralizedRCNNTransform - - -model_urls = { - 'fasterrcnn_resnet50_fpn_coco': - 'https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth', - 'fasterrcnn_mobilenet_v3_large_320_fpn_coco': - 'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth', - 'fasterrcnn_mobilenet_v3_large_fpn_coco': - 'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth' -} - - -def batch_concat_box_prediction_layers(box_cls, box_regression): - # type: (List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor] - box_cls_flattened = [] - box_regression_flattened = [] - # for each feature level, permute the outputs to make them be in the - # same format as the labels. Note that the labels are computed for - # all feature levels concatenated, so we keep the same representation - # for the objectness and the box_regression - for box_cls_per_level, box_regression_per_level in zip( - box_cls, box_regression - ): - N, AxC, H, W = box_cls_per_level.shape - Ax4 = box_regression_per_level.shape[1] - A = Ax4 // 4 - C = AxC // A - box_cls_per_level = permute_and_flatten( - box_cls_per_level, N, A, C, H, W - ) - box_cls_flattened.append(box_cls_per_level) - - box_regression_per_level = permute_and_flatten( - box_regression_per_level, N, A, 4, H, W - ) - box_regression_flattened.append(box_regression_per_level) - # concatenate on the first dimension (representing the feature levels), to - # take into account the way the labels were generated (with all feature maps - # being concatenated as well) - - batch_size = box_regression_flattened[0].shape[0] - - new_box_cls = [] - new_box_regression = [] - for batch_idx in range(batch_size): - element_box_cls = [torch.unsqueeze(item[batch_idx],dim=0) for item in box_cls_flattened] - element_box_regression = [torch.unsqueeze(item[batch_idx],dim=0) for item in box_regression_flattened] - - element_box_cls = torch.cat(element_box_cls, dim=1).flatten(0, -2) - element_box_regression = torch.cat(element_box_regression, dim=1).reshape(-1, 4) - new_box_cls.append(element_box_cls) - new_box_regression.append(element_box_regression) - - - return new_box_cls, new_box_regression - -class RegionProposalNetworkWILDS(RegionProposalNetwork): - def __init__(self, - anchor_generator, - head, - # - fg_iou_thresh, bg_iou_thresh, - batch_size_per_image, positive_fraction, - # - pre_nms_top_n, post_nms_top_n, nms_thresh): - super().__init__(anchor_generator, - head, - fg_iou_thresh, bg_iou_thresh, - batch_size_per_image, positive_fraction, - pre_nms_top_n, post_nms_top_n, nms_thresh) - - def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets): - # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor] - """ - Arguments: - objectness (Tensor) - pred_bbox_deltas (Tensor) - labels (List[Tensor]) - regression_targets (List[Tensor]) - Returns: - objectness_loss (Tensor) - box_loss (Tensor) - """ - objectness, pred_bbox_deltas = batch_concat_box_prediction_layers(objectness, pred_bbox_deltas) - - objectness_loss = [] - box_loss = [] - - for objectness_, regression_targets_,labels_,objectness_,pred_bbox_deltas_ in zip(objectness,regression_targets,labels,objectness,pred_bbox_deltas): - - sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(torch.unsqueeze(labels_,dim=0)) - - sampled_pos_inds = torch.where(torch.cat(sampled_pos_inds, dim=0))[0] - sampled_neg_inds = torch.where(torch.cat(sampled_neg_inds, dim=0))[0] - sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0) - - box_loss.append(det_utils.smooth_l1_loss( - pred_bbox_deltas_[sampled_pos_inds], - regression_targets_[sampled_pos_inds], - beta=1 / 9, - size_average=False, - ) / (sampled_inds.numel())) - - - objectness_loss.append(F.binary_cross_entropy_with_logits( - objectness_[sampled_inds].flatten(), labels_[sampled_inds] - )) - - return torch.stack(objectness_loss), torch.stack(box_loss) - - def forward(self, - images, # type: ImageList - features, # type: Dict[str, Tensor] - targets=None # type: Optional[List[Dict[str, Tensor]]] - ): - # type: (...) -> Tuple[List[Tensor], Dict[str, Tensor]] - """ - Arguments: - images (ImageList): images for which we want to compute the predictions - features (OrderedDict[Tensor]): features computed from the images that are - used for computing the predictions. Each tensor in the list - correspond to different feature levels - targets (List[Dict[Tensor]]): ground-truth boxes present in the image (optional). - If provided, each element in the dict should contain a field `boxes`, - with the locations of the ground-truth boxes. - Returns: - boxes (List[Tensor]): the predicted boxes from the RPN, one Tensor per - image. - losses (Dict[Tensor]): the losses for the model during training. During - testing, it is an empty dict. - """ - # RPN uses all feature maps that are available - features = list(features.values()) - objectness, pred_bbox_deltas = self.head(features) - anchors = self.anchor_generator(images, features) - - num_images = len(anchors) - num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness] - num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors] - - raw_objectness = objectness - raw_pred_bbox_deltas = pred_bbox_deltas - objectness, pred_bbox_deltas = \ - concat_box_prediction_layers(objectness, pred_bbox_deltas) - # apply pred_bbox_deltas to anchors to obtain the decoded proposals - # note that we detach the deltas because Faster R-CNN do not backprop through - # the proposals - proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors) - proposals = proposals.view(num_images, -1, 4) - - boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level) - losses = {} - - if self.training: - assert targets is not None - labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets) - regression_targets = self.box_coder.encode(matched_gt_boxes, anchors) - loss_objectness, loss_rpn_box_reg = self.compute_loss( - raw_objectness, raw_pred_bbox_deltas, labels, regression_targets) - - losses = { - "loss_objectness": loss_objectness, - "loss_rpn_box_reg": loss_rpn_box_reg, - } - return boxes, losses - -def fastrcnn_loss(class_logits, box_regression, labels, regression_targets): - # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor] - """ - Computes the loss for Faster R-CNN. - Arguments: - class_logits (Tensor) - box_regression (Tensor) - labels (list[BoxList]) - regression_targets (Tensor) - Returns: - classification_loss (Tensor) - box_loss (Tensor) - """ - - class_logits = torch.split(class_logits, 512,dim=0) - box_regression = torch.split(box_regression, 512,dim=0) - - classification_loss = [] - box_loss = [] - - for class_logits_, box_regression_, labels_, regression_targets_ in zip(class_logits, box_regression, labels, regression_targets): - classification_loss.append(F.cross_entropy(class_logits_, labels_)) - # get indices that correspond to the regression targets for - # the corresponding ground truth labels, to be used with - # advanced indexing - sampled_pos_inds_subset = torch.where(labels_ > 0)[0] - - labels_pos = labels_[sampled_pos_inds_subset] - N, num_classes = class_logits_.shape - - box_regression_ = box_regression_.reshape(N, -1, 4) - - box_loss_ = det_utils.smooth_l1_loss( - box_regression_[sampled_pos_inds_subset, labels_pos], - regression_targets_[sampled_pos_inds_subset], - beta=1 / 9, - size_average=False, - ) - box_loss.append(box_loss_ / labels_.numel()) - - return torch.stack(classification_loss), torch.stack(box_loss) - -class RoIHeadsWILDS(RoIHeads): - def __init__(self, box_roi_pool, box_head, box_predictor, box_fg_iou_thresh, box_bg_iou_thresh,box_batch_size_per_image,box_positive_fraction,bbox_reg_weights,box_score_thresh,box_nms_thresh,box_detections_per_img): - - super().__init__(box_roi_pool, box_head, box_predictor, - box_fg_iou_thresh, box_bg_iou_thresh, - box_batch_size_per_image, box_positive_fraction, - bbox_reg_weights, - box_score_thresh, box_nms_thresh, box_detections_per_img) - - def forward(self, - features, # type: Dict[str, Tensor] - proposals, # type: List[Tensor] - image_shapes, # type: List[Tuple[int, int]] - targets=None # type: Optional[List[Dict[str, Tensor]]] - ): - # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]] - """ - Arguments: - features (List[Tensor]) - proposals (List[Tensor[N, 4]]) - image_shapes (List[Tuple[H, W]]) - targets (List[Dict]) - """ - if targets is not None: - for t in targets: - # TODO: https://github.com/pytorch/pytorch/issues/26731 - floating_point_types = (torch.float, torch.double, torch.half) - assert t["boxes"].dtype in floating_point_types, 'target boxes must of float type' - assert t["labels"].dtype == torch.int64, 'target labels must of int64 type' - if self.has_keypoint(): - assert t["keypoints"].dtype == torch.float32, 'target keypoints must of float type' - - # here batch is maintained - if self.training: - proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets) - else: - labels = None - regression_targets = None - matched_idxs = None - - box_features = self.box_roi_pool(features, proposals, image_shapes) - - box_features = self.box_head(box_features) - - class_logits, box_regression = self.box_predictor(box_features) - result = torch.jit.annotate(List[Dict[str, torch.Tensor]], []) - losses = {} - - if self.training: - assert labels is not None and regression_targets is not None - - loss_classifier, loss_box_reg = fastrcnn_loss( - class_logits, box_regression, labels, regression_targets) - losses = { - "loss_classifier": loss_classifier, - "loss_box_reg": loss_box_reg - } - - boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes) - num_images = len(boxes) - for i in range(num_images): - result.append( - { - "boxes": boxes[i], - "labels": labels[i], - "scores": scores[i], - } - ) - - return result, losses - -def fasterrcnn_resnet50_fpn(pretrained=False, progress=True, - num_classes=91, pretrained_backbone=True, trainable_backbone_layers=3, **kwargs): - - assert trainable_backbone_layers <= 5 and trainable_backbone_layers >= 0 - # dont freeze any layers if pretrained model or backbone is not used - if not (pretrained or pretrained_backbone): - trainable_backbone_layers = 5 - if pretrained: - # no need to download the backbone if pretrained is set - pretrained_backbone = False - backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, trainable_layers=trainable_backbone_layers) - model = FastWILDS(backbone, 91, **kwargs) - if pretrained: - state_dict = load_state_dict_from_url(model_urls['fasterrcnn_resnet50_fpn_coco'], - progress=progress) - model.load_state_dict(state_dict) - - # get number of input features for the classifier - in_features = model.roi_heads.box_predictor.cls_score.in_features - # replace the pre-trained head with a new one - model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes+1) - - return model - -class FastWILDS(GeneralizedRCNN): - def __init__(self, backbone, num_classes=None, - # transform parameters - min_size=800, max_size=1333, - image_mean=None, image_std=None, - # RPN parameters - rpn_anchor_generator=None, rpn_head=None, - rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000, - rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000, - rpn_nms_thresh=0.7, - rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3, - rpn_batch_size_per_image=256, rpn_positive_fraction=0.5, - # Box parameters - box_roi_pool=None, box_head=None, box_predictor=None, - box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100, - box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5, - box_batch_size_per_image=512, box_positive_fraction=0.25, - bbox_reg_weights=None): - - if not hasattr(backbone, "out_channels"): - raise ValueError( - "backbone should contain an attribute out_channels " - "specifying the number of output channels (assumed to be the " - "same for all the levels)") - - assert isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))) - assert isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))) - - if num_classes is not None: - if box_predictor is not None: - raise ValueError("num_classes should be None when box_predictor is specified") - else: - if box_predictor is None: - raise ValueError("num_classes should not be None when box_predictor " - "is not specified") - - out_channels = backbone.out_channels - - if rpn_anchor_generator is None: - anchor_sizes = ((32,), (64,), (128,), (256,), (512,)) - aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) - rpn_anchor_generator = AnchorGenerator( - anchor_sizes, aspect_ratios - ) - if rpn_head is None: - rpn_head = RPNHead( - out_channels, rpn_anchor_generator.num_anchors_per_location()[0] - ) - - rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test) - rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test) - - rpn = RegionProposalNetworkWILDS( - rpn_anchor_generator, rpn_head, - rpn_fg_iou_thresh, rpn_bg_iou_thresh, - rpn_batch_size_per_image, rpn_positive_fraction, - rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh) - - if box_roi_pool is None: - box_roi_pool = MultiScaleRoIAlign( - featmap_names=['0', '1', '2', '3'], - output_size=7, - sampling_ratio=2) - - if box_head is None: - resolution = box_roi_pool.output_size[0] - representation_size = 1024 - box_head = TwoMLPHead( - out_channels * resolution ** 2, - representation_size) - - if box_predictor is None: - representation_size = 1024 - box_predictor = FastRCNNPredictor( - representation_size, - num_classes) - - roi_heads = RoIHeadsWILDS( - box_roi_pool, box_head, box_predictor, - box_fg_iou_thresh, box_bg_iou_thresh, - box_batch_size_per_image, box_positive_fraction, - bbox_reg_weights, - box_score_thresh, box_nms_thresh, box_detections_per_img) - - - image_mean = [0., 0., 0.] # small trick because images are already normalized - image_std = [1., 1., 1.] - transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std) - - super(FastWILDS, self).__init__(backbone, rpn, roi_heads, transform) - # Set your own forward pass - def forward(self, images, targets=None): - - - if self.training: - if targets is None: - raise ValueError("In training mode, targets should be passed") - assert targets is not None - for target in targets: - boxes = target["boxes"] - if isinstance(boxes, torch.Tensor): - if len(boxes.shape) != 2 or boxes.shape[-1] != 4: - raise ValueError("Expected target boxes to be a tensor" - "of shape [N, 4], got {:}.".format( - boxes.shape)) - else: - raise ValueError("Expected target boxes to be of type " - "Tensor, got {:}.".format(type(boxes))) - - original_image_sizes: List[Tuple[int, int]] = [] - for img in images: - val = img.shape[-2:] - assert len(val) == 2 - original_image_sizes.append((val[0], val[1])) - - - images, targets = self.transform(images, targets) - - # Check for degenerate boxes - # TODO: Move this to a function - if targets is not None: - for target_idx, target in enumerate(targets): - boxes = target["boxes"] - degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] - if degenerate_boxes.any(): - # print the first degenerate box - bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0] - degen_bb: List[float] = boxes[bb_idx].tolist() - raise ValueError("All bounding boxes should have positive height and width." - " Found invalid box {} for target at index {}." - .format(degen_bb, target_idx)) - - features = self.backbone(images.tensors) - if isinstance(features, torch.Tensor): - features = OrderedDict([('0', features)]) - - proposals, proposal_losses = self.rpn(images, features, targets) - - - detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets) - - - detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) - - for idx, det in enumerate(detections): - det["losses"] = {} - for k,v in proposal_losses.items(): - det["losses"][k] = v[idx] - for k,v in detector_losses.items(): - det["losses"][k] = v[idx] - - - return detections - - - - - - -class FasterRCNNLoss(nn.Module): - def __init__(self,device): - self.device = device - super().__init__() - - def forward(self, outputs, targets): - - - # loss values are loss_classifier loss_box_reg loss_objectness": loss_objectness, loss_rpn_box_reg - try: - elementwise_loss = torch.stack([sum(v for v in item["losses"].values()) for item in outputs]) - except: - elementwise_loss = torch.ones(len(outputs)).to(self.device) - - - - return elementwise_loss diff --git a/examples/models/initializer.py b/examples/models/initializer.py index a9580c93..4d414763 100644 --- a/examples/models/initializer.py +++ b/examples/models/initializer.py @@ -1,7 +1,12 @@ -import torch import torch.nn as nn - +import torchvision +from models.bert.bert import BertClassifier, BertFeaturizer +from models.bert.distilbert import DistilBertClassifier, DistilBertFeaturizer +from models.resnet_multispectral import ResNet18 from models.layers import Identity +from models.gnn import GINVirtual +from models.code_gpt import GPT2LMHeadLogit, GPT2FeaturizerLMHeadLogit +from transformers import GPT2Tokenizer def initialize_model(config, d_out, is_featurizer=False): """ @@ -31,7 +36,6 @@ def initialize_model(config, d_out, is_featurizer=False): name=config.model, d_out=d_out, **config.model_kwargs) - elif 'bert' in config.model: if is_featurizer: featurizer = initialize_bert_based_model(config, d_out, is_featurizer) @@ -39,28 +43,21 @@ def initialize_model(config, d_out, is_featurizer=False): model = (featurizer, classifier) else: model = initialize_bert_based_model(config, d_out) - elif config.model == 'resnet18_ms': # multispectral resnet 18 - from models.resnet_multispectral import ResNet18 if is_featurizer: featurizer = ResNet18(num_classes=None, **config.model_kwargs) classifier = nn.Linear(featurizer.d_out, d_out) model = (featurizer, classifier) else: model = ResNet18(num_classes=d_out, **config.model_kwargs) - elif config.model == 'gin-virtual': - from models.gnn import GINVirtual if is_featurizer: featurizer = GINVirtual(num_tasks=None, **config.model_kwargs) classifier = nn.Linear(featurizer.d_out, d_out) model = (featurizer, classifier) else: model = GINVirtual(num_tasks=d_out, **config.model_kwargs) - elif config.model == 'code-gpt-py': - from models.code_gpt import GPT2LMHeadLogit, GPT2FeaturizerLMHeadLogit - from transformers import GPT2Tokenizer name = 'microsoft/CodeGPT-small-py' tokenizer = GPT2Tokenizer.from_pretrained(name) if is_featurizer: @@ -72,41 +69,14 @@ def initialize_model(config, d_out, is_featurizer=False): else: model = GPT2LMHeadLogit.from_pretrained(name) model.resize_token_embeddings(len(tokenizer)) - elif config.model == 'logistic_regression': assert not is_featurizer, "Featurizer not supported for logistic regression" model = nn.Linear(out_features=d_out, **config.model_kwargs) - - elif config.model == 'fasterrcnn': - if is_featurizer: # TODO - raise NotImplementedError('Featurizer not implemented for detection yet') - else: - model = initialize_fasterrcnn_model(config, d_out) - model.needs_y = True - else: raise ValueError(f'Model: {config.model} not recognized.') - - # The `needs_y` attribute specifies whether the model's forward function - # needs to take in both (x, y). - # If False, Algorithm.process_batch will call model(x). - # If True, Algorithm.process_batch() will call model(x, y) during training, - # and model(x, None) during eval. - if not hasattr(model, 'needs_y'): - # Sometimes model is a tuple of (featurizer, classifier) - if isinstance(model, tuple): - for submodel in model: - submodel.needs_y = False - else: - model.needs_y = False - return model - def initialize_bert_based_model(config, d_out, is_featurizer=False): - from models.bert.bert import BertClassifier, BertFeaturizer - from models.bert.distilbert import DistilBertClassifier, DistilBertFeaturizer - if config.model == 'bert-base-uncased': if is_featurizer: model = BertFeaturizer.from_pretrained(config.model, **config.model_kwargs) @@ -128,8 +98,6 @@ def initialize_bert_based_model(config, d_out, is_featurizer=False): return model def initialize_torchvision_model(name, d_out, **kwargs): - import torchvision - # get constructor and last layer names if name == 'wideresnet50': constructor_name = 'wide_resnet50_2' @@ -155,13 +123,3 @@ def initialize_torchvision_model(name, d_out, **kwargs): model.d_out = d_out setattr(model, last_layer_name, last_layer) return model - - -def initialize_fasterrcnn_model(config, d_out): - - from models.detection.fasterrcnn import fasterrcnn_resnet50_fpn - - # load a model pre-trained pre-trained on COCO - model = fasterrcnn_resnet50_fpn(pretrained=config.model_kwargs["pretrained"],num_classes=d_out) - - return model diff --git a/examples/run_expt.py b/examples/run_expt.py index 19f6d570..173603ab 100644 --- a/examples/run_expt.py +++ b/examples/run_expt.py @@ -8,7 +8,6 @@ import sys from collections import defaultdict -sys.path.insert(1, os.path.join(sys.path[0], '..')) import wilds from wilds.common.data_loaders import get_train_loader, get_eval_loader from wilds.common.grouper import CombinatorialGrouper @@ -20,11 +19,8 @@ from configs.utils import populate_defaults import configs.supported as supported -import torch.multiprocessing - def main(): - - ''' to see default hyperparams for each dataset/model, look at configs/ ''' + ''' set default hyperparams in default_hyperparams.py ''' parser = argparse.ArgumentParser() # Required arguments @@ -65,8 +61,6 @@ def main(): # Objective parser.add_argument('--loss_function', choices = supported.losses) - parser.add_argument('--loss_kwargs', nargs='*', action=ParseKwargs, default={}, - help='keyword arguments for loss initialization passed as key1=value1 key2=value2') # Algorithm parser.add_argument('--groupby_fields', nargs='+') @@ -118,15 +112,10 @@ def main(): config = parser.parse_args() config = populate_defaults(config) - # For the GWHD dataset, we need to change the multiprocessing strategy or there will be - # too many open file descriptors - if config.dataset == 'gwhd': - torch.multiprocessing.set_sharing_strategy('file_system') - - # Set device + # set device config.device = torch.device("cuda:" + str(config.device)) if torch.cuda.is_available() else torch.device("cpu") - # Initialize logs + ## Initialize logs if os.path.exists(config.log_dir) and config.resume: resume=True mode='a' @@ -280,15 +269,12 @@ def main(): epoch = best_epoch else: epoch = config.eval_epoch - if epoch == best_epoch: - is_best = True evaluate( algorithm=algorithm, datasets=datasets, epoch=epoch, general_logger=logger, - config=config, - is_best=is_best) + config=config) logger.close() for split in datasets: diff --git a/examples/train.py b/examples/train.py index 93cc076e..774f3d1e 100644 --- a/examples/train.py +++ b/examples/train.py @@ -1,7 +1,8 @@ import os from tqdm import tqdm import torch -from utils import save_model, save_pred, get_pred_prefix, get_model_prefix, detach_and_clone, collate_list +from utils import save_model, save_pred, get_pred_prefix, get_model_prefix +import torch.autograd.profiler as profiler from configs.supported import process_outputs_functions def run_epoch(algorithm, dataset, general_logger, epoch, config, train): @@ -33,24 +34,23 @@ def run_epoch(algorithm, dataset, general_logger, epoch, config, train): # These tensors are already detached, but we need to clone them again # Otherwise they don't get garbage collected properly in some versions - # The extra detach is just for safety + # The subsequent detach is just for safety # (they should already be detached in batch_results) - epoch_y_true.append(detach_and_clone(batch_results['y_true'])) - y_pred = detach_and_clone(batch_results['y_pred']) + epoch_y_true.append(batch_results['y_true'].clone().detach()) + y_pred = batch_results['y_pred'].clone().detach() if config.process_outputs_function is not None: y_pred = process_outputs_functions[config.process_outputs_function](y_pred) epoch_y_pred.append(y_pred) - epoch_metadata.append(detach_and_clone(batch_results['metadata'])) + epoch_metadata.append(batch_results['metadata'].clone().detach()) if train and (batch_idx+1) % config.log_every==0: log_results(algorithm, dataset, general_logger, epoch, batch_idx) batch_idx += 1 - epoch_y_pred = collate_list(epoch_y_pred) - epoch_y_true = collate_list(epoch_y_true) - epoch_metadata = collate_list(epoch_metadata) - + epoch_y_pred = torch.cat(epoch_y_pred) + epoch_y_true = torch.cat(epoch_y_true) + epoch_metadata = torch.cat(epoch_metadata) results, results_str = dataset['dataset'].eval( epoch_y_pred, epoch_y_true, @@ -112,7 +112,7 @@ def train(algorithm, datasets, general_logger, config, epoch_offset, best_val_me general_logger.write('\n') -def evaluate(algorithm, datasets, epoch, general_logger, config, is_best): +def evaluate(algorithm, datasets, epoch, general_logger, config): algorithm.eval() for split, dataset in datasets.items(): if (not config.evaluate_all_splits) and (split not in config.eval_splits): @@ -123,20 +123,17 @@ def evaluate(algorithm, datasets, epoch, general_logger, config, is_best): iterator = tqdm(dataset['loader']) if config.progress_bar else dataset['loader'] for batch in iterator: batch_results = algorithm.evaluate(batch) - epoch_y_true.append(detach_and_clone(batch_results['y_true'])) - y_pred = detach_and_clone(batch_results['y_pred']) + epoch_y_true.append(batch_results['y_true'].clone().detach()) + y_pred = batch_results['y_pred'].clone().detach() if config.process_outputs_function is not None: y_pred = process_outputs_functions[config.process_outputs_function](y_pred) epoch_y_pred.append(y_pred) - epoch_metadata.append(detach_and_clone(batch_results['metadata'])) + epoch_metadata.append(batch_results['metadata'].clone().detach()) - epoch_y_pred = collate_list(epoch_y_pred) - epoch_y_true = collate_list(epoch_y_true) - epoch_metadata = collate_list(epoch_metadata) results, results_str = dataset['dataset'].eval( - epoch_y_pred, - epoch_y_true, - epoch_metadata) + torch.cat(epoch_y_pred), + torch.cat(epoch_y_true), + torch.cat(epoch_metadata)) results['epoch'] = epoch dataset['eval_logger'].log(results) @@ -145,7 +142,7 @@ def evaluate(algorithm, datasets, epoch, general_logger, config, is_best): # Skip saving train preds, since the train loader generally shuffles the data if split != 'train': - save_pred_if_needed(epoch_y_pred, dataset, epoch, config, is_best, force_save=True) + save_pred_if_needed(y_pred, dataset, epoch, config, is_best=False, force_save=True) def log_results(algorithm, dataset, general_logger, epoch, batch_idx): @@ -163,11 +160,11 @@ def save_pred_if_needed(y_pred, dataset, epoch, config, is_best, force_save=Fals if config.save_pred: prefix = get_pred_prefix(dataset, config) if force_save or (config.save_step is not None and (epoch + 1) % config.save_step == 0): - save_pred(y_pred, prefix + f'epoch:{epoch}_pred') - if (not force_save) and config.save_last: - save_pred(y_pred, prefix + f'epoch:last_pred') + save_pred(y_pred, prefix + f'epoch:{epoch}_pred.csv') + if config.save_last: + save_pred(y_pred, prefix + f'epoch:last_pred.csv') if config.save_best and is_best: - save_pred(y_pred, prefix + f'epoch:best_pred') + save_pred(y_pred, prefix + f'epoch:best_pred.csv') def save_model_if_needed(algorithm, dataset, epoch, config, is_best, best_val_metric): diff --git a/examples/transforms.py b/examples/transforms.py index bbcd88a4..bafbd42f 100644 --- a/examples/transforms.py +++ b/examples/transforms.py @@ -3,10 +3,6 @@ import torch def initialize_transform(transform_name, config, dataset): - """ - Transforms should take in a single (x, y) - and return (transformed_x, transformed_y). - """ if transform_name is None: return None elif transform_name=='bert': @@ -20,11 +16,6 @@ def initialize_transform(transform_name, config, dataset): else: raise ValueError(f"{transform_name} not recognized") -def transform_input_only(input_transform): - def transform(x, y): - return input_transform(x), y - return transform - def initialize_bert_transform(config): assert 'bert' in config.model assert config.max_token_length is not None @@ -50,7 +41,7 @@ def transform(text): dim=2) x = torch.squeeze(x, dim=0) # First shape dim is always 1 return x - return transform_input_only(transform) + return transform def getBertTokenizer(model): if model == 'bert-base-uncased': @@ -74,7 +65,7 @@ def initialize_image_base_transform(config, dataset): transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ] transform = transforms.Compose(transform_steps) - return transform_input_only(transform) + return transform def initialize_image_resize_and_center_crop_transform(config, dataset): """ @@ -93,7 +84,7 @@ def initialize_image_resize_and_center_crop_transform(config, dataset): transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) - return transform_input_only(transform) + return transform def initialize_poverty_train_transform(): transforms_ls = [ @@ -108,6 +99,5 @@ def transform_rgb(img): # bgr to rgb and back to bgr img[:3] = rgb_transform(img[:3][[2,1,0]])[[2,1,0]] return img - transform = transforms.Lambda(lambda x: transform_rgb(x)) - return transform_input_only(transform) + return transform diff --git a/examples/utils.py b/examples/utils.py index 73fa1b12..89780d62 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -176,16 +176,9 @@ def initialize_wandb(config): project=f"wilds") wandb.config.update(config) -def save_pred(y_pred, path_prefix): - # Single tensor - if torch.is_tensor(y_pred): - df = pd.DataFrame(y_pred.numpy()) - df.to_csv(path_prefix + '.csv', index=False, header=False) - # Dictionary - elif isinstance(y_pred, dict) or isinstance(y_pred, list): - torch.save(y_pred, path_prefix + '.pth') - else: - raise TypeError("Invalid type for save_pred") +def save_pred(y_pred, csv_path): + df = pd.DataFrame(y_pred.numpy()) + df.to_csv(csv_path, index=False, header=False) def get_replicate_str(dataset, config): if dataset['dataset'].dataset_name == 'poverty': @@ -210,59 +203,3 @@ def get_model_prefix(dataset, config): config.log_dir, f"{dataset_name}_{replicate_str}_") return prefix - -def move_to(obj, device): - if isinstance(obj, dict): - return {k: move_to(v, device) for k, v in obj.items()} - elif isinstance(obj, list): - return [move_to(v, device) for v in obj] - elif isinstance(obj, float) or isinstance(obj, int): - return obj - else: - # Assume obj is a Tensor or other type - # (like Batch, for MolPCBA) that supports .to(device) - return obj.to(device) - -def detach_and_clone(obj): - if torch.is_tensor(obj): - return obj.detach().clone() - elif isinstance(obj, dict): - return {k: detach_and_clone(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [detach_and_clone(v) for v in obj] - elif isinstance(obj, float) or isinstance(obj, int): - return obj - else: - raise TypeError("Invalid type for detach_and_clone") - -def collate_list(vec): - """ - If vec is a list of Tensors, it concatenates them all along the first dimension. - - If vec is a list of lists, it joins these lists together, but does not attempt to - recursively collate. This allows each element of the list to be, e.g., its own dict. - - If vec is a list of dicts (with the same keys in each dict), it returns a single dict - with the same keys. For each key, it recursively collates all entries in the list. - """ - if not isinstance(vec, list): - raise TypeError("collate_list must take in a list") - elem = vec[0] - if torch.is_tensor(elem): - return torch.cat(vec) - elif isinstance(elem, list): - return [obj for sublist in vec for obj in sublist] - elif isinstance(elem, dict): - return {k: collate_list([d[k] for d in vec]) for k in elem} - else: - raise TypeError("Elements of the list to collate must be tensors or dicts.") - -def remove_key(key): - """ - Returns a function that strips out a key from a dict. - """ - def remove(d): - if not isinstance(d, dict): - raise TypeError("remove_key must take in a dict") - return {k: v for (k,v) in d.items() if k != key} - return remove diff --git a/setup.py b/setup.py index 1fb6a2e5..9cd1f596 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,6 @@ 'tqdm>=4.53.0', 'outdated>=0.2.0', 'pytz>=2020.4', - 'torchvision==0.8.2' ], license='MIT', packages=setuptools.find_packages(exclude=['dataset_preprocessing', 'examples', 'examples.models', 'examples.models.bert']), diff --git a/wilds/__init__.py b/wilds/__init__.py index fe1bed0a..77f0ad5a 100644 --- a/wilds/__init__.py +++ b/wilds/__init__.py @@ -10,7 +10,6 @@ 'poverty', 'fmow', 'py150', - 'gwhd', ] additional_datasets = [ diff --git a/wilds/common/data_loaders.py b/wilds/common/data_loaders.py index 0cb23cee..b806832a 100644 --- a/wilds/common/data_loaders.py +++ b/wilds/common/data_loaders.py @@ -4,20 +4,20 @@ from torch.utils.data.sampler import WeightedRandomSampler, SubsetRandomSampler from wilds.common.utils import get_counts, split_into_groups -def get_train_loader(loader, dataset, batch_size, +def get_train_loader(loader, dataset, batch_size, uniform_over_groups=None, grouper=None, distinct_groups=True, n_groups_per_batch=None, **loader_kwargs): """ Constructs and returns the data loader for training. Args: - loader (str): Loader type. 'standard' for standard loaders and 'group' for group loaders, - which first samples groups and then samples a fixed number of examples belonging + which first samples groups and then samples a fixed number of examples belonging to each group. - dataset (WILDSDataset or WILDSSubset): Data - batch_size (int): Batch size - - uniform_over_groups (None or bool): Whether to sample the groups uniformly or according to the + - uniform_over_groups (None or bool): Whether to sample the groups uniformly or according to the natural data distribution. - Setting to None applies the defaults for each type of loaders. - For standard loaders, the default is False. For group loaders, + Setting to None applies the defaults for each type of loaders. + For standard loaders, the default is False. For group loaders, the default is True. - grouper (Grouper): Grouper used for group loaders or for uniform_over_groups=True - distinct_groups (bool): Whether to sample distinct_groups within each minibatch for group loaders. @@ -82,7 +82,7 @@ def get_eval_loader(loader, dataset, batch_size, grouper=None, **loader_kwargs): """ Constructs and returns the data loader for evaluation. Args: - - loader (str): Loader type. 'standard' for standard loaders. + - loader (str): Loader type. 'standard' for standard loaders. - dataset (WILDSDataset or WILDSSubset): Data - batch_size (int): Batch size - loader_kwargs: kwargs passed into torch DataLoader initialization. @@ -104,6 +104,7 @@ class GroupSampler: then sampling data from those groups. It drops the last batch if it's incomplete. """ + def __init__(self, group_ids, batch_size, n_groups_per_batch, uniform_over_groups, distinct_groups): @@ -128,13 +129,16 @@ def __init__(self, group_ids, batch_size, n_groups_per_batch, self.group_prob = unique_counts.numpy() / unique_counts.numpy().sum() def __iter__(self): + for batch_id in range(self.num_batches): + # Note that we are selecting group indices rather than groups groups_for_batch = np.random.choice( len(self.unique_groups), size=self.n_groups_per_batch, replace=(not self.distinct_groups), p=self.group_prob) + sampled_ids = [ np.random.choice( self.group_indices[group], @@ -145,6 +149,7 @@ def __iter__(self): # Flatten sampled_ids = np.concatenate(sampled_ids) + yield sampled_ids def __len__(self): diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py index ec93cc9f..0f5d7eb1 100644 --- a/wilds/common/metrics/all_metrics.py +++ b/wilds/common/metrics/all_metrics.py @@ -1,13 +1,10 @@ import torch import torch.nn as nn -from torchvision.ops.boxes import box_iou -from torchvision.models.detection._utils import Matcher -from torchvision.ops import nms, box_convert import numpy as np import torch.nn.functional as F from wilds.common.metrics.metric import Metric, ElementwiseMetric, MultiTaskMetric from wilds.common.metrics.loss import ElementwiseLoss -from wilds.common.utils import avg_over_groups, minimum, maximum, get_counts +from wilds.common.utils import avg_over_groups, minimum, maximum import sklearn.metrics from scipy.stats import pearsonr @@ -22,7 +19,7 @@ def binary_logits_to_score(logits): def multiclass_logits_to_pred(logits): """ - Takes multi-class logits of size (batch_size, ..., n_classes) and returns predictions + Takes multi-class logits of size (batch_size, ..., n_classes) and returns predictions by taking an argmax at the last dimension """ assert logits.dim() > 1 @@ -145,78 +142,3 @@ def _compute(self, y_pred, y_true): def worst(self, metrics): return minimum(metrics) - -class DummyMetric(Metric): - """ - For testing purposes. This Metric always returns -1. - """ - def __init__(self, prediction_fn=None, name=None): - self.prediction_fn = prediction_fn - if name is None: - name = 'dummy' - super().__init__(name=name) - - def _compute(self, y_pred, y_true): - return torch.tensor(-1) - - def _compute_group_wise(self, y_pred, y_true, g, n_groups): - group_metrics = torch.ones(n_groups, device=g.device) * -1 - group_counts = get_counts(g, n_groups) - worst_group_metric = self.worst(group_metrics) - return group_metrics, group_counts, worst_group_metric - - def worst(self, metrics): - return minimum(metrics) - -class DetectionAccuracy(ElementwiseMetric): - """ - Given a specific Intersection over union threshold, - determine the accuracy achieved for a one-class detector - """ - def __init__(self, iou_threshold=0.5, score_threshold=0.5, name=None): - self.iou_threshold = iou_threshold - self.score_threshold = score_threshold - if name is None: - name = "detection_acc" - super().__init__(name=name) - - def _compute_element_wise(self, y_pred, y_true): - batch_results = [] - for src_boxes, target in zip(y_true, y_pred): - target_boxes = target["boxes"] - target_scores = target["scores"] - - pred_boxes = target_boxes[target_scores > self.score_threshold] - det_accuracy = torch.mean(torch.stack([ self._accuracy(src_boxes["boxes"],pred_boxes,iou_thr) for iou_thr in np.arange(0.5,0.51,0.05)])) - batch_results.append(det_accuracy) - - return torch.tensor(batch_results) - - def _accuracy(self, src_boxes,pred_boxes , iou_threshold): - total_gt = len(src_boxes) - total_pred = len(pred_boxes) - if total_gt > 0 and total_pred > 0: - # Define the matcher and distance matrix based on iou - matcher = Matcher(iou_threshold,iou_threshold,allow_low_quality_matches=False) - match_quality_matrix = box_iou(src_boxes,pred_boxes) - results = matcher(match_quality_matrix) - true_positive = torch.count_nonzero(results.unique() != -1) - matched_elements = results[results > -1] - #in Matcher, a pred element can be matched only twice - false_positive = ( - torch.count_nonzero(results == -1) + - (len(matched_elements) - len(matched_elements.unique())) - ) - false_negative = total_gt - true_positive - acc = true_positive / ( true_positive + false_positive + false_negative ) - return true_positive / ( true_positive + false_positive + false_negative ) - elif total_gt == 0: - if total_pred > 0: - return torch.tensor(0.) - else: - return torch.tensor(1.) - elif total_gt > 0 and total_pred == 0: - return torch.tensor(0.) - - def worst(self, metrics): - return minimum(metrics) diff --git a/wilds/common/metrics/metric.py b/wilds/common/metrics/metric.py index 89582577..9c4372b0 100644 --- a/wilds/common/metrics/metric.py +++ b/wilds/common/metrics/metric.py @@ -1,5 +1,5 @@ import numpy as np -from wilds.common.utils import avg_over_groups, get_counts, numel +from wilds.common.utils import avg_over_groups, get_counts import torch class Metric: @@ -82,7 +82,7 @@ def compute(self, y_pred, y_true, return_dict=True): Output (return_dict=True): - results (dict): Dictionary of results, mapping metric.agg_metric_field to avg_metric """ - if numel(y_true) == 0: + if y_true.numel()==0: agg_metric = torch.tensor(0., device=y_true.device) else: agg_metric = self._compute(y_pred, y_true) @@ -133,7 +133,6 @@ def _compute_group_wise(self, y_pred, y_true, g, n_groups): self._compute( y_pred[g == group_idx], y_true[g == group_idx])) - group_metrics = torch.stack(group_metrics) worst_group_metric = self.worst(group_metrics[group_counts>0]) diff --git a/wilds/common/utils.py b/wilds/common/utils.py index ebf2f41d..7854393a 100644 --- a/wilds/common/utils.py +++ b/wilds/common/utils.py @@ -81,9 +81,8 @@ def avg_over_groups(v, g, n_groups): group_avgs (Tensor): Vector of length num_groups group_counts (Tensor) """ - - assert v.device==g.device + device = v.device assert v.numel()==g.numel() group_count = get_counts(g, n_groups) group_avgs = torch_scatter.scatter(src=v, index=g, dim_size=n_groups, reduce='mean') @@ -114,6 +113,7 @@ def subsample_idxs(idxs, num=5000, take_rest=False, seed=None): idxs = idxs[:num] return idxs + def shuffle_arr(arr, seed=None): seed = (seed + 548207) if seed is not None else None rng = np.random.default_rng(seed) @@ -126,11 +126,3 @@ def threshold_at_recall(y_pred, y_true, global_recall=60): """ Calculate the model threshold to use to achieve a desired global_recall level. Assumes that y_true is a vector of the true binary labels.""" return np.percentile(y_pred[y_true == 1], 100-global_recall) - -def numel(obj): - if torch.is_tensor(obj): - return obj.numel() - elif isinstance(obj, list): - return len(obj) - else: - raise TypeError("Invalid type for numel") diff --git a/wilds/datasets/gwhd_dataset.py b/wilds/datasets/gwhd_dataset.py deleted file mode 100644 index 221b480a..00000000 --- a/wilds/datasets/gwhd_dataset.py +++ /dev/null @@ -1,146 +0,0 @@ -import numpy as np -import pandas as pd -import torch -from pathlib import Path -from PIL import Image -from wilds.datasets.wilds_dataset import WILDSDataset -from wilds.common.grouper import CombinatorialGrouper -from wilds.common.metrics.all_metrics import DetectionAccuracy - - -def _collate_fn(batch): - """ - Stack x (batch[0]) and metadata (batch[2]), but not y. - originally, batch = (item1, item2, item3, item4) - after zip, batch = [(item1[0], item2[0], ..), ..] - """ - batch = list(zip(*batch)) - batch[0] = torch.stack(batch[0]) - batch[1] = list(batch[1]) - batch[2] = torch.stack(batch[2]) - return tuple(batch) - -class GWHDDataset(WILDSDataset): - """ - The GWHD-wilds wheat head localization dataset. - This is a modified version of the original Global Wheat Head Dataset. - This dataset is not part of the official WILDS benchmark. - We provide it for convenience and to reproduce observations discussed in the WILDS paper. - Supported `split_scheme`: - 'official' for WILDS related tasks. - To reproduce the baseline, several splits are needed: - - to train a model on train domains and test against a all test split: 'train_in-dist' - - "benchmark_biased" ; "benchmark_in-dist" - Input (x): - 1024x1024 RGB images of wheat field canopy between flowering and ripening. - Output (y): - y is a nx4-dimensional vector where each line represents a box coordinate (x_min,y_min,x_max,y_max) - Metadata: - Each image is annotated with the ID of the domain it came from (integer from 0 to 10). - Website: - http://www.global-wheat.com/ - Original publication: - @article{david_global_2020, - title = {Global {Wheat} {Head} {Detection} ({GWHD}) {Dataset}: {A} {Large} and {Diverse} {Dataset} of {High}-{Resolution} {RGB}-{Labelled} {Images} to {Develop} and {Benchmark} {Wheat} {Head} {Detection} {Methods}}, - volume = {2020}, - url = {https://doi.org/10.34133/2020/3521852}, - doi = {10.34133/2020/3521852}, - journal = {Plant Phenomics}, - author = {David, Etienne and Madec, Simon and Sadeghi-Tehran, Pouria and Aasen, Helge and Zheng, Bangyou and Liu, Shouyang and Kirchgessner, Norbert and Ishikawa, Goro and Nagasawa, Koichi and Badhon, Minhajul A. and Pozniak, Curtis and de Solan, Benoit and Hund, Andreas and Chapman, Scott C. and Baret, Frédéric and Stavness, Ian and Guo, Wei}, - month = aug, - year = {2020}, - note = {Publisher: AAAS}, - pages = {3521852}, - } - License: - This dataset is distributed under the MIT license. - https://github.com/snap-stanford/ogb/blob/master/LICENSE - """ - - _dataset_name = 'gwhd' - _versions_dict = { - '2.0': { - 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x42fa9775eacc453489a428abd59a437d/contents/blob/', - 'compressed_size': None}} - - def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): - - self._version = version - self._data_dir = self.initialize_data_dir(root_dir, download) - self._original_resolution = (1024, 1024) - self.root = Path(self.data_dir) - self._is_detection = True - self._is_classification = False - self._y_size = None - self._n_classes = 1 - - self._split_scheme = split_scheme - - # Get filenames - - if split_scheme =="official": - train_data_df = pd.read_csv(self.root / f'official_train.csv') - val_data_df = pd.read_csv(self.root / f'official_val.csv') - test_data_df = pd.read_csv(self.root / f'official_test.csv') - - elif split_scheme == "benchmark_biased": - train_data_df = pd.read_csv(self.root / f'official_train.csv') - val_data_df = pd.read_csv(self.root / f'official_val.csv') - test_data_df = pd.read_csv(self.root / f'in-dist_test.csv') - - elif split_scheme == "benchmark_in-dist": - train_data_df = pd.read_csv(self.root / f'in-dist_train.csv') - val_data_df = pd.read_csv(self.root / f'official_val.csv') - test_data_df = pd.read_csv(self.root / f'in-dist_test.csv') - - - self._image_array = [] - self._split_array, self._y_array, self._metadata_array = [], [], [] - - for i, df in enumerate([train_data_df, val_data_df, test_data_df]): - self._image_array.extend(list(df['image'].values)) - labels = list(df['labels'].values) - self._split_array.extend([i] * len(labels)) - - labels = [{ - "boxes": torch.stack([ - torch.tensor([int(float(i)) for i in box.split(" ")]) - for box in boxes.split(";") - ]), - "labels": torch.tensor([1]*len(list(boxes.split(";")))).long() - } if type(boxes) != float else { - "boxes": torch.empty(0,4), - "labels": torch.empty(0,dtype=torch.long) - } for boxes in labels] - - self._y_array.extend(labels) - self._metadata_array.extend(list(df['group'].values)) - - self._split_array = np.array(self._split_array) - - self._metadata_array = torch.tensor(self._metadata_array, - dtype=torch.long).unsqueeze(1) - self._metadata_fields = ['location'] - - self._eval_grouper = CombinatorialGrouper( - dataset=self, - groupby_fields=['location']) - - self._metric = DetectionAccuracy() # TODO - self._collate = _collate_fn - - super().__init__(root_dir, download, split_scheme) - - def get_input(self, idx): - """ - Returns x for a given idx. - """ - img_filename = self.root / "images" / self._image_array[idx] - x = Image.open(img_filename) - return x - - def eval(self, y_pred, y_true, metadata): - return self.standard_group_eval( - self._metric, - self._eval_grouper, - y_pred, y_true, metadata) diff --git a/wilds/datasets/wilds_dataset.py b/wilds/datasets/wilds_dataset.py index 8812f957..1f8bf21a 100644 --- a/wilds/datasets/wilds_dataset.py +++ b/wilds/datasets/wilds_dataset.py @@ -95,8 +95,8 @@ def check_init(self): assert 'train' in self.split_dict assert 'val' in self.split_dict - # Check the form of the required arrays - assert (isinstance(self.y_array, torch.Tensor) or isinstance(self.y_array, list)) + # Check that required arrays are Tensors + assert isinstance(self.y_array, torch.Tensor), 'y_array must be a torch.Tensor' assert isinstance(self.metadata_array, torch.Tensor), 'metadata_array must be a torch.Tensor' # Check that dimensions match @@ -106,10 +106,6 @@ def check_init(self): # Check metadata assert len(self.metadata_array.shape) == 2 assert len(self.metadata_fields) == self.metadata_array.shape[1] - - # Check that it is not both classification and detection - assert not (self.is_classification and self.is_detection) - # For convenience, include y in metadata_fields if y_size == 1 if self.y_size == 1: assert 'y' in self.metadata_fields @@ -246,15 +242,9 @@ def n_classes(self): def is_classification(self): """ Boolean. True if the task is classification, and false otherwise. + Used for logging purposes. """ - return getattr(self, '_is_classification', (self.n_classes is not None)) - - @property - def is_detection(self): - """ - Boolean. True if the task is detection, and false otherwise. - """ - return getattr(self, '_is_detection', False) + return (self.n_classes is not None) @property def metadata_fields(self): @@ -453,7 +443,7 @@ def __init__(self, dataset, indices, transform): def __getitem__(self, idx): x, y, metadata = self.dataset[self.indices[idx]] if self.transform is not None: - x, y = self.transform(x, y) + x = self.transform(x) return x, y, metadata def __len__(self): diff --git a/wilds/get_dataset.py b/wilds/get_dataset.py index cfa5f2c7..1073100f 100644 --- a/wilds/get_dataset.py +++ b/wilds/get_dataset.py @@ -55,7 +55,7 @@ def get_dataset(dataset, version=None, **dataset_kwargs): elif dataset == 'poverty': if version == '1.0': from wilds.datasets.archive.poverty_v1_0_dataset import PovertyMapDataset - else: + else: from wilds.datasets.poverty_dataset import PovertyMapDataset return PovertyMapDataset(version=version, **dataset_kwargs) @@ -77,7 +77,3 @@ def get_dataset(dataset, version=None, **dataset_kwargs): elif dataset == 'sqf': from wilds.datasets.sqf_dataset import SQFDataset return SQFDataset(version=version, **dataset_kwargs) - - elif dataset == 'gwhd': - from wilds.datasets.gwhd_dataset import GWHDDataset - return GWHDDataset(version=version, **dataset_kwargs) From d9fd59ffbb5f7cab83d200984a5b4021b90e2562 Mon Sep 17 00:00:00 2001 From: Etienne David Date: Sat, 8 May 2021 16:53:36 +0200 Subject: [PATCH 173/244] clean fasterrcnn and addressing most questions. Constitution of v0.9 is still pending --- .gitignore | 8 --- .vscode/settings.json | 3 - examples/configs/model.py | 3 +- examples/models/detection/fasterrcnn.py | 12 ++++ examples/models/initializer.py | 2 +- wilds/datasets/gwhd_dataset.py | 80 ++++++++++++++++--------- 6 files changed, 66 insertions(+), 42 deletions(-) delete mode 100644 .vscode/settings.json diff --git a/.gitignore b/.gitignore index e3b37d5a..ac33582a 100644 --- a/.gitignore +++ b/.gitignore @@ -2,11 +2,3 @@ __pycache__ build dist wilds.egg-info -data -logs -test_faster -paper* -.vscode -*sh -*ipynb -experiences \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index d0c7592c..00000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "python.pythonPath": "/home/gdaubige/anaconda3/envs/wilds/bin/python" -} \ No newline at end of file diff --git a/examples/configs/model.py b/examples/configs/model.py index 37afbf22..582e2d6c 100644 --- a/examples/configs/model.py +++ b/examples/configs/model.py @@ -39,7 +39,8 @@ 'logistic_regression': {}, 'fasterrcnn': { 'model_kwargs': { - 'pretrained': True, + 'pretrained_model': True, + 'pretrained_backbone': True, } } } diff --git a/examples/models/detection/fasterrcnn.py b/examples/models/detection/fasterrcnn.py index d553ea0f..6abb6b16 100644 --- a/examples/models/detection/fasterrcnn.py +++ b/examples/models/detection/fasterrcnn.py @@ -1,3 +1,15 @@ +""" +This module contains all the necessary modifications to adapt "Faster-RCNN" of the torchvision library +in order to be able to calculate the loss per image + +It has been developped from torchvision=0.8.2 and did not has been tested on other versions + +All credits : +https://github.com/pytorch/vision/blob/master/LICENSE +https://github.com/pytorch/vision/tree/master/torchvision/models/detection + +""" + import torch import torch.nn as nn import torchvision diff --git a/examples/models/initializer.py b/examples/models/initializer.py index a9580c93..b9a9bb7e 100644 --- a/examples/models/initializer.py +++ b/examples/models/initializer.py @@ -162,6 +162,6 @@ def initialize_fasterrcnn_model(config, d_out): from models.detection.fasterrcnn import fasterrcnn_resnet50_fpn # load a model pre-trained pre-trained on COCO - model = fasterrcnn_resnet50_fpn(pretrained=config.model_kwargs["pretrained"],num_classes=d_out) + model = fasterrcnn_resnet50_fpn(pretrained=config.model_kwargs["pretrained_model"],pretrained_backbone=config.model_kwargs["pretrained_backbone"],num_classes=d_out) return model diff --git a/wilds/datasets/gwhd_dataset.py b/wilds/datasets/gwhd_dataset.py index 221b480a..736a0b50 100644 --- a/wilds/datasets/gwhd_dataset.py +++ b/wilds/datasets/gwhd_dataset.py @@ -7,7 +7,21 @@ from wilds.common.grouper import CombinatorialGrouper from wilds.common.metrics.all_metrics import DetectionAccuracy - +def decode_string(BoxesString): + """ + Small method to decode the BoxesString + """ + if BoxesString == "no_box": + return np.zeros((0,4)) + else: + try: + boxes = np.array([np.array([int(i) for i in box.split(" ")]) + for box in BoxesString.split(";")]) + return boxes + except: + print(BoxesString) + print("Submission is not well formatted. empty boxes will be returned") + return np.zeros((0,4)) def _collate_fn(batch): """ Stack x (batch[0]) and metadata (batch[2]), but not y. @@ -23,20 +37,18 @@ def _collate_fn(batch): class GWHDDataset(WILDSDataset): """ The GWHD-wilds wheat head localization dataset. - This is a modified version of the original Global Wheat Head Dataset. + This is a modified version of the original Global Wheat Head Dataset 2021. This dataset is not part of the official WILDS benchmark. We provide it for convenience and to reproduce observations discussed in the WILDS paper. Supported `split_scheme`: - 'official' for WILDS related tasks. - To reproduce the baseline, several splits are needed: - - to train a model on train domains and test against a all test split: 'train_in-dist' - - "benchmark_biased" ; "benchmark_in-dist" + - 'official' for WILDS related tasks. + - 'in-dist' and 'ood_with_subsampled_test' to reproduce the baseline described in the paper. WARNING: these splits are not accessible before v1.0 Input (x): - 1024x1024 RGB images of wheat field canopy between flowering and ripening. + 1024x1024 RGB images of wheat field canopy starting from anthesis (flowering) to ripening. Output (y): - y is a nx4-dimensional vector where each line represents a box coordinate (x_min,y_min,x_max,y_max) + y is a nx4-dimensional vector where each line represents a box coordinate (x_min, y_min, x_max, y_max) Metadata: - Each image is annotated with the ID of the domain it came from (integer from 0 to 10). + Each image is annotated with the ID of the domain (location_date_sensor) it came from (integer from 0 to 46). Website: http://www.global-wheat.com/ Original publication: @@ -54,13 +66,15 @@ class GWHDDataset(WILDSDataset): } License: This dataset is distributed under the MIT license. - https://github.com/snap-stanford/ogb/blob/master/LICENSE """ _dataset_name = 'gwhd' + + # Version 0.9 corresponds to the final dataset, but without the test label. It can be used to train + # a model but no validation nor test metrics are available before 5th July 2021 _versions_dict = { - '2.0': { - 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x42fa9775eacc453489a428abd59a437d/contents/blob/', + '0.9': { + 'download_url': '', 'compressed_size': None}} def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): @@ -83,13 +97,19 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' val_data_df = pd.read_csv(self.root / f'official_val.csv') test_data_df = pd.read_csv(self.root / f'official_test.csv') - elif split_scheme == "benchmark_biased": - train_data_df = pd.read_csv(self.root / f'official_train.csv') - val_data_df = pd.read_csv(self.root / f'official_val.csv') - test_data_df = pd.read_csv(self.root / f'in-dist_test.csv') - - elif split_scheme == "benchmark_in-dist": - train_data_df = pd.read_csv(self.root / f'in-dist_train.csv') + elif split_scheme == "ood_with_subsampled_test": + if version == "0.9": + print("Warning: ood_with_subsampled_test is not available in 0.9") + else: + train_data_df = pd.read_csv(self.root / f'official_train.csv') + val_data_df = pd.read_csv(self.root / f'official_val.csv') + test_data_df = pd.read_csv(self.root / f'in-dist_test.csv') + + elif split_scheme == "in-dist": + if version == "0.9": + print("Warning: ood_with_subsampled_test is not available in 0.9") + else: + train_data_df = pd.read_csv(self.root / f'in-dist_train.csv') val_data_df = pd.read_csv(self.root / f'official_val.csv') test_data_df = pd.read_csv(self.root / f'in-dist_test.csv') @@ -99,19 +119,21 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' for i, df in enumerate([train_data_df, val_data_df, test_data_df]): self._image_array.extend(list(df['image'].values)) - labels = list(df['labels'].values) - self._split_array.extend([i] * len(labels)) - + boxes_string = list(df['BoxesString'].values) + all_boxes = [decode_string(box_string) for box_string in boxes_string] + + self._split_array.extend([i] * len(all_boxes)) + labels = [{ "boxes": torch.stack([ - torch.tensor([int(float(i)) for i in box.split(" ")]) - for box in boxes.split(";") + torch.tensor(box) + for box in boxes ]), "labels": torch.tensor([1]*len(list(boxes.split(";")))).long() - } if type(boxes) != float else { + } if len(boxes) > 0 else { "boxes": torch.empty(0,4), "labels": torch.empty(0,dtype=torch.long) - } for boxes in labels] + } for boxes in all_boxes] self._y_array.extend(labels) self._metadata_array.extend(list(df['group'].values)) @@ -120,13 +142,13 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' self._metadata_array = torch.tensor(self._metadata_array, dtype=torch.long).unsqueeze(1) - self._metadata_fields = ['location'] + self._metadata_fields = ['domain'] self._eval_grouper = CombinatorialGrouper( dataset=self, - groupby_fields=['location']) + groupby_fields=['domain']) - self._metric = DetectionAccuracy() # TODO + self._metric = DetectionAccuracy() self._collate = _collate_fn super().__init__(root_dir, download, split_scheme) From 468a9a1bb30f998aaebb2274861180619b0abf77 Mon Sep 17 00:00:00 2001 From: Etienne David Date: Sun, 9 May 2021 20:58:29 +0200 Subject: [PATCH 174/244] v0.9 ready --- examples/configs/datasets.py | 2 +- wilds/datasets/gwhd_dataset.py | 17 +++++++++-------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 17bb8640..47030ebb 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -291,7 +291,7 @@ 'n_classes': 1, 'pretrained': True}, 'loss_function': 'fasterrcnn_criterion', - 'groupby_fields': ['location'], + 'groupby_fields': ['location_date_sensor'], 'val_metric': 'detection_acc_avg', # TODO 'val_metric_decreasing': False, 'algo_log_metric': None, # TODO diff --git a/wilds/datasets/gwhd_dataset.py b/wilds/datasets/gwhd_dataset.py index 736a0b50..432154b3 100644 --- a/wilds/datasets/gwhd_dataset.py +++ b/wilds/datasets/gwhd_dataset.py @@ -74,7 +74,7 @@ class GWHDDataset(WILDSDataset): # a model but no validation nor test metrics are available before 5th July 2021 _versions_dict = { '0.9': { - 'download_url': '', + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x8ba9122a41454997afdfb78762d390cf/contents/blob/', 'compressed_size': None}} def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): @@ -110,17 +110,18 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' print("Warning: ood_with_subsampled_test is not available in 0.9") else: train_data_df = pd.read_csv(self.root / f'in-dist_train.csv') - val_data_df = pd.read_csv(self.root / f'official_val.csv') - test_data_df = pd.read_csv(self.root / f'in-dist_test.csv') + val_data_df = pd.read_csv(self.root / f'official_val.csv') + test_data_df = pd.read_csv(self.root / f'in-dist_test.csv') self._image_array = [] self._split_array, self._y_array, self._metadata_array = [], [], [] for i, df in enumerate([train_data_df, val_data_df, test_data_df]): - self._image_array.extend(list(df['image'].values)) + self._image_array.extend(list(df['image_name'].values)) boxes_string = list(df['BoxesString'].values) all_boxes = [decode_string(box_string) for box_string in boxes_string] + self._split_array.extend([i] * len(all_boxes)) @@ -129,24 +130,24 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' torch.tensor(box) for box in boxes ]), - "labels": torch.tensor([1]*len(list(boxes.split(";")))).long() + "labels": torch.tensor([1]*len(boxes)).long() } if len(boxes) > 0 else { "boxes": torch.empty(0,4), "labels": torch.empty(0,dtype=torch.long) } for boxes in all_boxes] self._y_array.extend(labels) - self._metadata_array.extend(list(df['group'].values)) + self._metadata_array.extend(list(df['domain'].values)) self._split_array = np.array(self._split_array) self._metadata_array = torch.tensor(self._metadata_array, dtype=torch.long).unsqueeze(1) - self._metadata_fields = ['domain'] + self._metadata_fields = ['location_date_sensor'] self._eval_grouper = CombinatorialGrouper( dataset=self, - groupby_fields=['domain']) + groupby_fields=['location_date_sensor']) self._metric = DetectionAccuracy() self._collate = _collate_fn From 96d3a8eec5b2987f4f35ace731e38236ff44c52d Mon Sep 17 00:00:00 2001 From: kohpangwei Date: Sun, 9 May 2021 14:39:25 -0700 Subject: [PATCH 175/244] Revert "Revert "GWHD and FasterRCNN support"" --- .gitignore | 8 + .vscode/settings.json | 3 + examples/algorithms/algorithm.py | 15 +- examples/algorithms/group_algorithm.py | 4 +- examples/algorithms/initializer.py | 29 +- examples/algorithms/single_model_algorithm.py | 17 +- examples/configs/datasets.py | 26 + examples/configs/model.py | 11 +- examples/configs/supported.py | 30 +- examples/losses.py | 23 + examples/models/detection/fasterrcnn.py | 511 ++++++++++++++++++ examples/models/initializer.py | 56 +- examples/run_expt.py | 22 +- examples/train.py | 45 +- examples/transforms.py | 18 +- examples/utils.py | 69 ++- setup.py | 1 + wilds/__init__.py | 1 + wilds/common/data_loaders.py | 17 +- wilds/common/metrics/all_metrics.py | 82 ++- wilds/common/metrics/metric.py | 5 +- wilds/common/utils.py | 12 +- wilds/datasets/gwhd_dataset.py | 146 +++++ wilds/datasets/wilds_dataset.py | 20 +- wilds/get_dataset.py | 6 +- 25 files changed, 1069 insertions(+), 108 deletions(-) create mode 100644 .vscode/settings.json create mode 100644 examples/losses.py create mode 100644 examples/models/detection/fasterrcnn.py create mode 100644 wilds/datasets/gwhd_dataset.py diff --git a/.gitignore b/.gitignore index 1d3b5479..acf51ee9 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,11 @@ build dist venv wilds.egg-info +data +logs +test_faster +paper* +.vscode +*sh +*ipynb +experiences \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..d0c7592c --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python.pythonPath": "/home/gdaubige/anaconda3/envs/wilds/bin/python" +} \ No newline at end of file diff --git a/examples/algorithms/algorithm.py b/examples/algorithms/algorithm.py index c93d960a..5c734766 100644 --- a/examples/algorithms/algorithm.py +++ b/examples/algorithms/algorithm.py @@ -1,5 +1,6 @@ import torch import torch.nn as nn +from utils import move_to, detach_and_clone class Algorithm(nn.Module): def __init__(self, device): @@ -93,19 +94,13 @@ def sanitize_dict(self, in_dict, to_out_device=True): Helper function that sanitizes dictionaries by: - moving to the specified output device - removing any gradient information - - turning any Tensor of size 1 to a simple number + - detaching and cloning the tensors Args: - in_dict (dictionary) Output: - out_dict (dictionary): sanitized version of in_dict """ - out_dict = {} - for k, v in in_dict.items(): - if isinstance(v, torch.Tensor): - v_out = v.detach().clone() - if to_out_device: - v_out = v_out.to(self.out_device) - else: - v_out = v - out_dict[k] = v_out + out_dict = detach_and_clone(in_dict) + if to_out_device: + out_dict = move_to(out_dict, self.out_device) return out_dict diff --git a/examples/algorithms/group_algorithm.py b/examples/algorithms/group_algorithm.py index 54cac1a8..eb0b95c2 100644 --- a/examples/algorithms/group_algorithm.py +++ b/examples/algorithms/group_algorithm.py @@ -3,7 +3,7 @@ from algorithms.algorithm import Algorithm from utils import update_average from scheduler import step_scheduler -from wilds.common.utils import get_counts +from wilds.common.utils import get_counts, numel class GroupAlgorithm(Algorithm): """ @@ -57,7 +57,7 @@ def update_log(self, results): results['y_pred'], results['y_true'], return_dict=False).item() - count = results['y_true'].numel() + count = numel(results['y_true']) # transfer other statistics in the results dictionary for field in self.logged_fields: diff --git a/examples/algorithms/initializer.py b/examples/algorithms/initializer.py index 00748cfc..180e9ff5 100644 --- a/examples/algorithms/initializer.py +++ b/examples/algorithms/initializer.py @@ -3,7 +3,8 @@ from algorithms.groupDRO import GroupDRO from algorithms.deepCORAL import DeepCORAL from algorithms.IRM import IRM -from configs.supported import algo_log_metrics, losses +from configs.supported import algo_log_metrics +from losses import initialize_loss def initialize_algorithm(config, datasets, train_grouper): train_dataset = datasets['train']['dataset'] @@ -11,23 +12,27 @@ def initialize_algorithm(config, datasets, train_grouper): # Configure the final layer of the networks used # The code below are defaults. Edit this if you need special config for your model. - if (train_dataset.is_classification) and (train_dataset.y_size == 1): - # For single-task classification, we have one output per class + if train_dataset.is_classification: + if train_dataset.y_size == 1: + # For single-task classification, we have one output per class + d_out = train_dataset.n_classes + elif train_dataset.y_size is None: + d_out = train_dataset.n_classes + elif (train_dataset.y_size > 1) and (train_dataset.n_classes == 2): + # For multi-task binary classification (each output is the logit for each binary class) + d_out = train_dataset.y_size + else: + raise RuntimeError('d_out not defined.') + elif train_dataset.is_detection: + # For detection, d_out is the number of classes d_out = train_dataset.n_classes - elif (train_dataset.is_classification) and (train_dataset.y_size is None): - d_out = train_dataset.n_classes - elif (train_dataset.is_classification) and (train_dataset.y_size > 1) and (train_dataset.n_classes == 2): - # For multi-task binary classification (each output is the logit for each binary class) - d_out = train_dataset.y_size - elif (not train_dataset.is_classification): + else: # For regression, we have one output per target dimension d_out = train_dataset.y_size - else: - raise RuntimeError('d_out not defined.') # Other config n_train_steps = len(train_loader) * config.n_epochs - loss = losses[config.loss_function] + loss = initialize_loss(config, d_out) metric = algo_log_metrics[config.algo_log_metric] if config.algorithm=='ERM': diff --git a/examples/algorithms/single_model_algorithm.py b/examples/algorithms/single_model_algorithm.py index e368b88f..f01c21bb 100644 --- a/examples/algorithms/single_model_algorithm.py +++ b/examples/algorithms/single_model_algorithm.py @@ -3,6 +3,7 @@ from scheduler import initialize_scheduler from optimizer import initialize_optimizer from torch.nn.utils import clip_grad_norm_ +from utils import move_to class SingleModelAlgorithm(GroupAlgorithm): """ @@ -47,11 +48,19 @@ def process_batch(self, batch): - y_true """ x, y_true, metadata = batch - x = x.to(self.device) - y_true = y_true.to(self.device) - g = self.grouper.metadata_to_group(metadata).to(self.device) - outputs = self.model(x) + x = move_to(x, self.device) + y_true = move_to(y_true, self.device) + g = move_to(self.grouper.metadata_to_group(metadata), self.device) + + if self.model.needs_y: + if self.training: + outputs = self.model(x, y_true) + else: + outputs = self.model(x, None) + else: + outputs = self.model(x) + results = { 'g': g, 'y_true': y_true, diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index cd2d1d6f..17bb8640 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -282,6 +282,32 @@ 'n_epochs': 4, 'process_outputs_function': None, }, + 'gwhd': { + 'split_scheme': 'official', + 'model': 'fasterrcnn', + 'train_transform': 'image_base', + 'eval_transform': 'image_base', + 'model_kwargs': { + 'n_classes': 1, + 'pretrained': True}, + 'loss_function': 'fasterrcnn_criterion', + 'groupby_fields': ['location'], + 'val_metric': 'detection_acc_avg', # TODO + 'val_metric_decreasing': False, + 'algo_log_metric': None, # TODO + 'optimizer': 'Adam', + 'optimizer_kwargs': {}, + 'scheduler': None, + 'batch_size': 4, + 'lr': 1e-5, + 'weight_decay': 1e-4, + 'n_epochs': 10, + 'loader_kwargs': { + 'num_workers': 1, + 'pin_memory': True, + }, + 'process_outputs_function': None, + } } ########################################## diff --git a/examples/configs/model.py b/examples/configs/model.py index 46714bbe..37afbf22 100644 --- a/examples/configs/model.py +++ b/examples/configs/model.py @@ -15,19 +15,19 @@ 'scheduler': 'linear_schedule_with_warmup', }, 'densenet121': { - 'model_kwargs':{ + 'model_kwargs': { 'pretrained':True, }, 'target_resolution': (224, 224), }, 'wideresnet50': { - 'model_kwargs':{ + 'model_kwargs': { 'pretrained':True, }, 'target_resolution': (224, 224), }, 'resnet50': { - 'model_kwargs':{ + 'model_kwargs': { 'pretrained':True, }, 'target_resolution': (224, 224), @@ -37,4 +37,9 @@ 'target_resolution': (224, 224), }, 'logistic_regression': {}, + 'fasterrcnn': { + 'model_kwargs': { + 'pretrained': True, + } + } } diff --git a/examples/configs/supported.py b/examples/configs/supported.py index 8b66b74e..b0fd9aab 100644 --- a/examples/configs/supported.py +++ b/examples/configs/supported.py @@ -1,18 +1,6 @@ -import torch.nn as nn -import torch -import sys, os - # metrics -from wilds.common.metrics.loss import ElementwiseLoss, Loss, MultiTaskLoss from wilds.common.metrics.all_metrics import Accuracy, MultiTaskAccuracy, MSE, multiclass_logits_to_pred, binary_logits_to_pred -losses = { - 'cross_entropy': ElementwiseLoss(loss_fn=nn.CrossEntropyLoss(reduction='none')), - 'lm_cross_entropy': MultiTaskLoss(loss_fn=nn.CrossEntropyLoss(reduction='none')), - 'mse': MSE(name='loss'), - 'multitask_bce': MultiTaskLoss(loss_fn=nn.BCEWithLogitsLoss(reduction='none')), -} - algo_log_metrics = { 'accuracy': Accuracy(prediction_fn=multiclass_logits_to_pred), 'mse': MSE(), @@ -27,11 +15,23 @@ None: None, } -# see initialize_*() functions for correspondence -transforms = ['bert', 'image_base', 'image_resize_and_center_crop', 'poverty_train'] +# See models/initializer.py models = ['resnet18_ms', 'resnet50', 'resnet34', 'wideresnet50', 'densenet121', 'bert-base-uncased', 'distilbert-base-uncased', - 'gin-virtual', 'logistic_regression', 'code-gpt-py'] + 'gin-virtual', 'logistic_regression', 'code-gpt-py', + 'fasterrcnn'] + +# See algorithms/initializer.py algorithms = ['ERM', 'groupDRO', 'deepCORAL', 'IRM'] + +# See optimizer.py optimizers = ['SGD', 'Adam', 'AdamW'] + +# See scheduler.py schedulers = ['linear_schedule_with_warmup', 'ReduceLROnPlateau', 'StepLR'] + +# See transforms.py +transforms = ['bert', 'image_base', 'image_resize_and_center_crop', 'poverty_train'] + +# See losses.py +losses = ['cross_entropy', 'lm_cross_entropy', 'MSE', 'multitask_bce', 'fasterrcnn_criterion'] diff --git a/examples/losses.py b/examples/losses.py new file mode 100644 index 00000000..e969a10e --- /dev/null +++ b/examples/losses.py @@ -0,0 +1,23 @@ +import torch.nn as nn +from wilds.common.metrics.loss import ElementwiseLoss, Loss, MultiTaskLoss +from wilds.common.metrics.all_metrics import MSE + +def initialize_loss(config, d_out): + if config.loss_function == 'cross_entropy': + return ElementwiseLoss(loss_fn=nn.CrossEntropyLoss(reduction='none')) + + elif config.loss_function == 'lm_cross_entropy': + return MultiTaskLoss(loss_fn=nn.CrossEntropyLoss(reduction='none')) + + elif config.loss_function == 'mse': + return MSE(name='loss') + + elif config.loss_function == 'multitask_bce': + return MultiTaskLoss(loss_fn=nn.BCEWithLogitsLoss(reduction='none')) + + elif config.loss_function == 'fasterrcnn_criterion': + from examples.models.detection.fasterrcnn import FasterRCNNLoss + return ElementwiseLoss(loss_fn=FasterRCNNLoss(config.device)) + + else: + raise ValueError(f'config.loss_function {config.loss_function} not recognized') diff --git a/examples/models/detection/fasterrcnn.py b/examples/models/detection/fasterrcnn.py new file mode 100644 index 00000000..d553ea0f --- /dev/null +++ b/examples/models/detection/fasterrcnn.py @@ -0,0 +1,511 @@ +import torch +import torch.nn as nn +import torchvision +from collections import OrderedDict +import torch +from torch import nn, Tensor +import warnings +from typing import Tuple, List, Dict, Optional, Union + +from torch import nn + + +import torchvision +from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, FasterRCNN +from torchvision.models.detection.backbone_utils import resnet_fpn_backbone +from torchvision.models.utils import load_state_dict_from_url + + +from torchvision.ops import misc as misc_nn_ops +from torchvision.ops import MultiScaleRoIAlign + + +from torchvision.models.detection.anchor_utils import AnchorGenerator +from torchvision.models.detection.generalized_rcnn import GeneralizedRCNN +from torchvision.models.detection.faster_rcnn import TwoMLPHead + +from torchvision.models.detection.rpn import RPNHead, RegionProposalNetwork, concat_box_prediction_layers,permute_and_flatten +from torchvision.models.detection.roi_heads import RoIHeads + +from torchvision.models.detection import _utils as det_utils +from torch.nn import functional as F +from torchvision.models.detection.transform import GeneralizedRCNNTransform + + +model_urls = { + 'fasterrcnn_resnet50_fpn_coco': + 'https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth', + 'fasterrcnn_mobilenet_v3_large_320_fpn_coco': + 'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth', + 'fasterrcnn_mobilenet_v3_large_fpn_coco': + 'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth' +} + + +def batch_concat_box_prediction_layers(box_cls, box_regression): + # type: (List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor] + box_cls_flattened = [] + box_regression_flattened = [] + # for each feature level, permute the outputs to make them be in the + # same format as the labels. Note that the labels are computed for + # all feature levels concatenated, so we keep the same representation + # for the objectness and the box_regression + for box_cls_per_level, box_regression_per_level in zip( + box_cls, box_regression + ): + N, AxC, H, W = box_cls_per_level.shape + Ax4 = box_regression_per_level.shape[1] + A = Ax4 // 4 + C = AxC // A + box_cls_per_level = permute_and_flatten( + box_cls_per_level, N, A, C, H, W + ) + box_cls_flattened.append(box_cls_per_level) + + box_regression_per_level = permute_and_flatten( + box_regression_per_level, N, A, 4, H, W + ) + box_regression_flattened.append(box_regression_per_level) + # concatenate on the first dimension (representing the feature levels), to + # take into account the way the labels were generated (with all feature maps + # being concatenated as well) + + batch_size = box_regression_flattened[0].shape[0] + + new_box_cls = [] + new_box_regression = [] + for batch_idx in range(batch_size): + element_box_cls = [torch.unsqueeze(item[batch_idx],dim=0) for item in box_cls_flattened] + element_box_regression = [torch.unsqueeze(item[batch_idx],dim=0) for item in box_regression_flattened] + + element_box_cls = torch.cat(element_box_cls, dim=1).flatten(0, -2) + element_box_regression = torch.cat(element_box_regression, dim=1).reshape(-1, 4) + new_box_cls.append(element_box_cls) + new_box_regression.append(element_box_regression) + + + return new_box_cls, new_box_regression + +class RegionProposalNetworkWILDS(RegionProposalNetwork): + def __init__(self, + anchor_generator, + head, + # + fg_iou_thresh, bg_iou_thresh, + batch_size_per_image, positive_fraction, + # + pre_nms_top_n, post_nms_top_n, nms_thresh): + super().__init__(anchor_generator, + head, + fg_iou_thresh, bg_iou_thresh, + batch_size_per_image, positive_fraction, + pre_nms_top_n, post_nms_top_n, nms_thresh) + + def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets): + # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor] + """ + Arguments: + objectness (Tensor) + pred_bbox_deltas (Tensor) + labels (List[Tensor]) + regression_targets (List[Tensor]) + Returns: + objectness_loss (Tensor) + box_loss (Tensor) + """ + objectness, pred_bbox_deltas = batch_concat_box_prediction_layers(objectness, pred_bbox_deltas) + + objectness_loss = [] + box_loss = [] + + for objectness_, regression_targets_,labels_,objectness_,pred_bbox_deltas_ in zip(objectness,regression_targets,labels,objectness,pred_bbox_deltas): + + sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(torch.unsqueeze(labels_,dim=0)) + + sampled_pos_inds = torch.where(torch.cat(sampled_pos_inds, dim=0))[0] + sampled_neg_inds = torch.where(torch.cat(sampled_neg_inds, dim=0))[0] + sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0) + + box_loss.append(det_utils.smooth_l1_loss( + pred_bbox_deltas_[sampled_pos_inds], + regression_targets_[sampled_pos_inds], + beta=1 / 9, + size_average=False, + ) / (sampled_inds.numel())) + + + objectness_loss.append(F.binary_cross_entropy_with_logits( + objectness_[sampled_inds].flatten(), labels_[sampled_inds] + )) + + return torch.stack(objectness_loss), torch.stack(box_loss) + + def forward(self, + images, # type: ImageList + features, # type: Dict[str, Tensor] + targets=None # type: Optional[List[Dict[str, Tensor]]] + ): + # type: (...) -> Tuple[List[Tensor], Dict[str, Tensor]] + """ + Arguments: + images (ImageList): images for which we want to compute the predictions + features (OrderedDict[Tensor]): features computed from the images that are + used for computing the predictions. Each tensor in the list + correspond to different feature levels + targets (List[Dict[Tensor]]): ground-truth boxes present in the image (optional). + If provided, each element in the dict should contain a field `boxes`, + with the locations of the ground-truth boxes. + Returns: + boxes (List[Tensor]): the predicted boxes from the RPN, one Tensor per + image. + losses (Dict[Tensor]): the losses for the model during training. During + testing, it is an empty dict. + """ + # RPN uses all feature maps that are available + features = list(features.values()) + objectness, pred_bbox_deltas = self.head(features) + anchors = self.anchor_generator(images, features) + + num_images = len(anchors) + num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness] + num_anchors_per_level = [s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors] + + raw_objectness = objectness + raw_pred_bbox_deltas = pred_bbox_deltas + objectness, pred_bbox_deltas = \ + concat_box_prediction_layers(objectness, pred_bbox_deltas) + # apply pred_bbox_deltas to anchors to obtain the decoded proposals + # note that we detach the deltas because Faster R-CNN do not backprop through + # the proposals + proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors) + proposals = proposals.view(num_images, -1, 4) + + boxes, scores = self.filter_proposals(proposals, objectness, images.image_sizes, num_anchors_per_level) + losses = {} + + if self.training: + assert targets is not None + labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets) + regression_targets = self.box_coder.encode(matched_gt_boxes, anchors) + loss_objectness, loss_rpn_box_reg = self.compute_loss( + raw_objectness, raw_pred_bbox_deltas, labels, regression_targets) + + losses = { + "loss_objectness": loss_objectness, + "loss_rpn_box_reg": loss_rpn_box_reg, + } + return boxes, losses + +def fastrcnn_loss(class_logits, box_regression, labels, regression_targets): + # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor] + """ + Computes the loss for Faster R-CNN. + Arguments: + class_logits (Tensor) + box_regression (Tensor) + labels (list[BoxList]) + regression_targets (Tensor) + Returns: + classification_loss (Tensor) + box_loss (Tensor) + """ + + class_logits = torch.split(class_logits, 512,dim=0) + box_regression = torch.split(box_regression, 512,dim=0) + + classification_loss = [] + box_loss = [] + + for class_logits_, box_regression_, labels_, regression_targets_ in zip(class_logits, box_regression, labels, regression_targets): + classification_loss.append(F.cross_entropy(class_logits_, labels_)) + # get indices that correspond to the regression targets for + # the corresponding ground truth labels, to be used with + # advanced indexing + sampled_pos_inds_subset = torch.where(labels_ > 0)[0] + + labels_pos = labels_[sampled_pos_inds_subset] + N, num_classes = class_logits_.shape + + box_regression_ = box_regression_.reshape(N, -1, 4) + + box_loss_ = det_utils.smooth_l1_loss( + box_regression_[sampled_pos_inds_subset, labels_pos], + regression_targets_[sampled_pos_inds_subset], + beta=1 / 9, + size_average=False, + ) + box_loss.append(box_loss_ / labels_.numel()) + + return torch.stack(classification_loss), torch.stack(box_loss) + +class RoIHeadsWILDS(RoIHeads): + def __init__(self, box_roi_pool, box_head, box_predictor, box_fg_iou_thresh, box_bg_iou_thresh,box_batch_size_per_image,box_positive_fraction,bbox_reg_weights,box_score_thresh,box_nms_thresh,box_detections_per_img): + + super().__init__(box_roi_pool, box_head, box_predictor, + box_fg_iou_thresh, box_bg_iou_thresh, + box_batch_size_per_image, box_positive_fraction, + bbox_reg_weights, + box_score_thresh, box_nms_thresh, box_detections_per_img) + + def forward(self, + features, # type: Dict[str, Tensor] + proposals, # type: List[Tensor] + image_shapes, # type: List[Tuple[int, int]] + targets=None # type: Optional[List[Dict[str, Tensor]]] + ): + # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]] + """ + Arguments: + features (List[Tensor]) + proposals (List[Tensor[N, 4]]) + image_shapes (List[Tuple[H, W]]) + targets (List[Dict]) + """ + if targets is not None: + for t in targets: + # TODO: https://github.com/pytorch/pytorch/issues/26731 + floating_point_types = (torch.float, torch.double, torch.half) + assert t["boxes"].dtype in floating_point_types, 'target boxes must of float type' + assert t["labels"].dtype == torch.int64, 'target labels must of int64 type' + if self.has_keypoint(): + assert t["keypoints"].dtype == torch.float32, 'target keypoints must of float type' + + # here batch is maintained + if self.training: + proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets) + else: + labels = None + regression_targets = None + matched_idxs = None + + box_features = self.box_roi_pool(features, proposals, image_shapes) + + box_features = self.box_head(box_features) + + class_logits, box_regression = self.box_predictor(box_features) + result = torch.jit.annotate(List[Dict[str, torch.Tensor]], []) + losses = {} + + if self.training: + assert labels is not None and regression_targets is not None + + loss_classifier, loss_box_reg = fastrcnn_loss( + class_logits, box_regression, labels, regression_targets) + losses = { + "loss_classifier": loss_classifier, + "loss_box_reg": loss_box_reg + } + + boxes, scores, labels = self.postprocess_detections(class_logits, box_regression, proposals, image_shapes) + num_images = len(boxes) + for i in range(num_images): + result.append( + { + "boxes": boxes[i], + "labels": labels[i], + "scores": scores[i], + } + ) + + return result, losses + +def fasterrcnn_resnet50_fpn(pretrained=False, progress=True, + num_classes=91, pretrained_backbone=True, trainable_backbone_layers=3, **kwargs): + + assert trainable_backbone_layers <= 5 and trainable_backbone_layers >= 0 + # dont freeze any layers if pretrained model or backbone is not used + if not (pretrained or pretrained_backbone): + trainable_backbone_layers = 5 + if pretrained: + # no need to download the backbone if pretrained is set + pretrained_backbone = False + backbone = resnet_fpn_backbone('resnet50', pretrained_backbone, trainable_layers=trainable_backbone_layers) + model = FastWILDS(backbone, 91, **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls['fasterrcnn_resnet50_fpn_coco'], + progress=progress) + model.load_state_dict(state_dict) + + # get number of input features for the classifier + in_features = model.roi_heads.box_predictor.cls_score.in_features + # replace the pre-trained head with a new one + model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes+1) + + return model + +class FastWILDS(GeneralizedRCNN): + def __init__(self, backbone, num_classes=None, + # transform parameters + min_size=800, max_size=1333, + image_mean=None, image_std=None, + # RPN parameters + rpn_anchor_generator=None, rpn_head=None, + rpn_pre_nms_top_n_train=2000, rpn_pre_nms_top_n_test=1000, + rpn_post_nms_top_n_train=2000, rpn_post_nms_top_n_test=1000, + rpn_nms_thresh=0.7, + rpn_fg_iou_thresh=0.7, rpn_bg_iou_thresh=0.3, + rpn_batch_size_per_image=256, rpn_positive_fraction=0.5, + # Box parameters + box_roi_pool=None, box_head=None, box_predictor=None, + box_score_thresh=0.05, box_nms_thresh=0.5, box_detections_per_img=100, + box_fg_iou_thresh=0.5, box_bg_iou_thresh=0.5, + box_batch_size_per_image=512, box_positive_fraction=0.25, + bbox_reg_weights=None): + + if not hasattr(backbone, "out_channels"): + raise ValueError( + "backbone should contain an attribute out_channels " + "specifying the number of output channels (assumed to be the " + "same for all the levels)") + + assert isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))) + assert isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))) + + if num_classes is not None: + if box_predictor is not None: + raise ValueError("num_classes should be None when box_predictor is specified") + else: + if box_predictor is None: + raise ValueError("num_classes should not be None when box_predictor " + "is not specified") + + out_channels = backbone.out_channels + + if rpn_anchor_generator is None: + anchor_sizes = ((32,), (64,), (128,), (256,), (512,)) + aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) + rpn_anchor_generator = AnchorGenerator( + anchor_sizes, aspect_ratios + ) + if rpn_head is None: + rpn_head = RPNHead( + out_channels, rpn_anchor_generator.num_anchors_per_location()[0] + ) + + rpn_pre_nms_top_n = dict(training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test) + rpn_post_nms_top_n = dict(training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test) + + rpn = RegionProposalNetworkWILDS( + rpn_anchor_generator, rpn_head, + rpn_fg_iou_thresh, rpn_bg_iou_thresh, + rpn_batch_size_per_image, rpn_positive_fraction, + rpn_pre_nms_top_n, rpn_post_nms_top_n, rpn_nms_thresh) + + if box_roi_pool is None: + box_roi_pool = MultiScaleRoIAlign( + featmap_names=['0', '1', '2', '3'], + output_size=7, + sampling_ratio=2) + + if box_head is None: + resolution = box_roi_pool.output_size[0] + representation_size = 1024 + box_head = TwoMLPHead( + out_channels * resolution ** 2, + representation_size) + + if box_predictor is None: + representation_size = 1024 + box_predictor = FastRCNNPredictor( + representation_size, + num_classes) + + roi_heads = RoIHeadsWILDS( + box_roi_pool, box_head, box_predictor, + box_fg_iou_thresh, box_bg_iou_thresh, + box_batch_size_per_image, box_positive_fraction, + bbox_reg_weights, + box_score_thresh, box_nms_thresh, box_detections_per_img) + + + image_mean = [0., 0., 0.] # small trick because images are already normalized + image_std = [1., 1., 1.] + transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std) + + super(FastWILDS, self).__init__(backbone, rpn, roi_heads, transform) + # Set your own forward pass + def forward(self, images, targets=None): + + + if self.training: + if targets is None: + raise ValueError("In training mode, targets should be passed") + assert targets is not None + for target in targets: + boxes = target["boxes"] + if isinstance(boxes, torch.Tensor): + if len(boxes.shape) != 2 or boxes.shape[-1] != 4: + raise ValueError("Expected target boxes to be a tensor" + "of shape [N, 4], got {:}.".format( + boxes.shape)) + else: + raise ValueError("Expected target boxes to be of type " + "Tensor, got {:}.".format(type(boxes))) + + original_image_sizes: List[Tuple[int, int]] = [] + for img in images: + val = img.shape[-2:] + assert len(val) == 2 + original_image_sizes.append((val[0], val[1])) + + + images, targets = self.transform(images, targets) + + # Check for degenerate boxes + # TODO: Move this to a function + if targets is not None: + for target_idx, target in enumerate(targets): + boxes = target["boxes"] + degenerate_boxes = boxes[:, 2:] <= boxes[:, :2] + if degenerate_boxes.any(): + # print the first degenerate box + bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0] + degen_bb: List[float] = boxes[bb_idx].tolist() + raise ValueError("All bounding boxes should have positive height and width." + " Found invalid box {} for target at index {}." + .format(degen_bb, target_idx)) + + features = self.backbone(images.tensors) + if isinstance(features, torch.Tensor): + features = OrderedDict([('0', features)]) + + proposals, proposal_losses = self.rpn(images, features, targets) + + + detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets) + + + detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) + + for idx, det in enumerate(detections): + det["losses"] = {} + for k,v in proposal_losses.items(): + det["losses"][k] = v[idx] + for k,v in detector_losses.items(): + det["losses"][k] = v[idx] + + + return detections + + + + + + +class FasterRCNNLoss(nn.Module): + def __init__(self,device): + self.device = device + super().__init__() + + def forward(self, outputs, targets): + + + # loss values are loss_classifier loss_box_reg loss_objectness": loss_objectness, loss_rpn_box_reg + try: + elementwise_loss = torch.stack([sum(v for v in item["losses"].values()) for item in outputs]) + except: + elementwise_loss = torch.ones(len(outputs)).to(self.device) + + + + return elementwise_loss diff --git a/examples/models/initializer.py b/examples/models/initializer.py index 4d414763..a9580c93 100644 --- a/examples/models/initializer.py +++ b/examples/models/initializer.py @@ -1,12 +1,7 @@ +import torch import torch.nn as nn -import torchvision -from models.bert.bert import BertClassifier, BertFeaturizer -from models.bert.distilbert import DistilBertClassifier, DistilBertFeaturizer -from models.resnet_multispectral import ResNet18 + from models.layers import Identity -from models.gnn import GINVirtual -from models.code_gpt import GPT2LMHeadLogit, GPT2FeaturizerLMHeadLogit -from transformers import GPT2Tokenizer def initialize_model(config, d_out, is_featurizer=False): """ @@ -36,6 +31,7 @@ def initialize_model(config, d_out, is_featurizer=False): name=config.model, d_out=d_out, **config.model_kwargs) + elif 'bert' in config.model: if is_featurizer: featurizer = initialize_bert_based_model(config, d_out, is_featurizer) @@ -43,21 +39,28 @@ def initialize_model(config, d_out, is_featurizer=False): model = (featurizer, classifier) else: model = initialize_bert_based_model(config, d_out) + elif config.model == 'resnet18_ms': # multispectral resnet 18 + from models.resnet_multispectral import ResNet18 if is_featurizer: featurizer = ResNet18(num_classes=None, **config.model_kwargs) classifier = nn.Linear(featurizer.d_out, d_out) model = (featurizer, classifier) else: model = ResNet18(num_classes=d_out, **config.model_kwargs) + elif config.model == 'gin-virtual': + from models.gnn import GINVirtual if is_featurizer: featurizer = GINVirtual(num_tasks=None, **config.model_kwargs) classifier = nn.Linear(featurizer.d_out, d_out) model = (featurizer, classifier) else: model = GINVirtual(num_tasks=d_out, **config.model_kwargs) + elif config.model == 'code-gpt-py': + from models.code_gpt import GPT2LMHeadLogit, GPT2FeaturizerLMHeadLogit + from transformers import GPT2Tokenizer name = 'microsoft/CodeGPT-small-py' tokenizer = GPT2Tokenizer.from_pretrained(name) if is_featurizer: @@ -69,14 +72,41 @@ def initialize_model(config, d_out, is_featurizer=False): else: model = GPT2LMHeadLogit.from_pretrained(name) model.resize_token_embeddings(len(tokenizer)) + elif config.model == 'logistic_regression': assert not is_featurizer, "Featurizer not supported for logistic regression" model = nn.Linear(out_features=d_out, **config.model_kwargs) + + elif config.model == 'fasterrcnn': + if is_featurizer: # TODO + raise NotImplementedError('Featurizer not implemented for detection yet') + else: + model = initialize_fasterrcnn_model(config, d_out) + model.needs_y = True + else: raise ValueError(f'Model: {config.model} not recognized.') + + # The `needs_y` attribute specifies whether the model's forward function + # needs to take in both (x, y). + # If False, Algorithm.process_batch will call model(x). + # If True, Algorithm.process_batch() will call model(x, y) during training, + # and model(x, None) during eval. + if not hasattr(model, 'needs_y'): + # Sometimes model is a tuple of (featurizer, classifier) + if isinstance(model, tuple): + for submodel in model: + submodel.needs_y = False + else: + model.needs_y = False + return model + def initialize_bert_based_model(config, d_out, is_featurizer=False): + from models.bert.bert import BertClassifier, BertFeaturizer + from models.bert.distilbert import DistilBertClassifier, DistilBertFeaturizer + if config.model == 'bert-base-uncased': if is_featurizer: model = BertFeaturizer.from_pretrained(config.model, **config.model_kwargs) @@ -98,6 +128,8 @@ def initialize_bert_based_model(config, d_out, is_featurizer=False): return model def initialize_torchvision_model(name, d_out, **kwargs): + import torchvision + # get constructor and last layer names if name == 'wideresnet50': constructor_name = 'wide_resnet50_2' @@ -123,3 +155,13 @@ def initialize_torchvision_model(name, d_out, **kwargs): model.d_out = d_out setattr(model, last_layer_name, last_layer) return model + + +def initialize_fasterrcnn_model(config, d_out): + + from models.detection.fasterrcnn import fasterrcnn_resnet50_fpn + + # load a model pre-trained pre-trained on COCO + model = fasterrcnn_resnet50_fpn(pretrained=config.model_kwargs["pretrained"],num_classes=d_out) + + return model diff --git a/examples/run_expt.py b/examples/run_expt.py index 173603ab..19f6d570 100644 --- a/examples/run_expt.py +++ b/examples/run_expt.py @@ -8,6 +8,7 @@ import sys from collections import defaultdict +sys.path.insert(1, os.path.join(sys.path[0], '..')) import wilds from wilds.common.data_loaders import get_train_loader, get_eval_loader from wilds.common.grouper import CombinatorialGrouper @@ -19,8 +20,11 @@ from configs.utils import populate_defaults import configs.supported as supported +import torch.multiprocessing + def main(): - ''' set default hyperparams in default_hyperparams.py ''' + + ''' to see default hyperparams for each dataset/model, look at configs/ ''' parser = argparse.ArgumentParser() # Required arguments @@ -61,6 +65,8 @@ def main(): # Objective parser.add_argument('--loss_function', choices = supported.losses) + parser.add_argument('--loss_kwargs', nargs='*', action=ParseKwargs, default={}, + help='keyword arguments for loss initialization passed as key1=value1 key2=value2') # Algorithm parser.add_argument('--groupby_fields', nargs='+') @@ -112,10 +118,15 @@ def main(): config = parser.parse_args() config = populate_defaults(config) - # set device + # For the GWHD dataset, we need to change the multiprocessing strategy or there will be + # too many open file descriptors + if config.dataset == 'gwhd': + torch.multiprocessing.set_sharing_strategy('file_system') + + # Set device config.device = torch.device("cuda:" + str(config.device)) if torch.cuda.is_available() else torch.device("cpu") - ## Initialize logs + # Initialize logs if os.path.exists(config.log_dir) and config.resume: resume=True mode='a' @@ -269,12 +280,15 @@ def main(): epoch = best_epoch else: epoch = config.eval_epoch + if epoch == best_epoch: + is_best = True evaluate( algorithm=algorithm, datasets=datasets, epoch=epoch, general_logger=logger, - config=config) + config=config, + is_best=is_best) logger.close() for split in datasets: diff --git a/examples/train.py b/examples/train.py index 774f3d1e..93cc076e 100644 --- a/examples/train.py +++ b/examples/train.py @@ -1,8 +1,7 @@ import os from tqdm import tqdm import torch -from utils import save_model, save_pred, get_pred_prefix, get_model_prefix -import torch.autograd.profiler as profiler +from utils import save_model, save_pred, get_pred_prefix, get_model_prefix, detach_and_clone, collate_list from configs.supported import process_outputs_functions def run_epoch(algorithm, dataset, general_logger, epoch, config, train): @@ -34,23 +33,24 @@ def run_epoch(algorithm, dataset, general_logger, epoch, config, train): # These tensors are already detached, but we need to clone them again # Otherwise they don't get garbage collected properly in some versions - # The subsequent detach is just for safety + # The extra detach is just for safety # (they should already be detached in batch_results) - epoch_y_true.append(batch_results['y_true'].clone().detach()) - y_pred = batch_results['y_pred'].clone().detach() + epoch_y_true.append(detach_and_clone(batch_results['y_true'])) + y_pred = detach_and_clone(batch_results['y_pred']) if config.process_outputs_function is not None: y_pred = process_outputs_functions[config.process_outputs_function](y_pred) epoch_y_pred.append(y_pred) - epoch_metadata.append(batch_results['metadata'].clone().detach()) + epoch_metadata.append(detach_and_clone(batch_results['metadata'])) if train and (batch_idx+1) % config.log_every==0: log_results(algorithm, dataset, general_logger, epoch, batch_idx) batch_idx += 1 - epoch_y_pred = torch.cat(epoch_y_pred) - epoch_y_true = torch.cat(epoch_y_true) - epoch_metadata = torch.cat(epoch_metadata) + epoch_y_pred = collate_list(epoch_y_pred) + epoch_y_true = collate_list(epoch_y_true) + epoch_metadata = collate_list(epoch_metadata) + results, results_str = dataset['dataset'].eval( epoch_y_pred, epoch_y_true, @@ -112,7 +112,7 @@ def train(algorithm, datasets, general_logger, config, epoch_offset, best_val_me general_logger.write('\n') -def evaluate(algorithm, datasets, epoch, general_logger, config): +def evaluate(algorithm, datasets, epoch, general_logger, config, is_best): algorithm.eval() for split, dataset in datasets.items(): if (not config.evaluate_all_splits) and (split not in config.eval_splits): @@ -123,17 +123,20 @@ def evaluate(algorithm, datasets, epoch, general_logger, config): iterator = tqdm(dataset['loader']) if config.progress_bar else dataset['loader'] for batch in iterator: batch_results = algorithm.evaluate(batch) - epoch_y_true.append(batch_results['y_true'].clone().detach()) - y_pred = batch_results['y_pred'].clone().detach() + epoch_y_true.append(detach_and_clone(batch_results['y_true'])) + y_pred = detach_and_clone(batch_results['y_pred']) if config.process_outputs_function is not None: y_pred = process_outputs_functions[config.process_outputs_function](y_pred) epoch_y_pred.append(y_pred) - epoch_metadata.append(batch_results['metadata'].clone().detach()) + epoch_metadata.append(detach_and_clone(batch_results['metadata'])) + epoch_y_pred = collate_list(epoch_y_pred) + epoch_y_true = collate_list(epoch_y_true) + epoch_metadata = collate_list(epoch_metadata) results, results_str = dataset['dataset'].eval( - torch.cat(epoch_y_pred), - torch.cat(epoch_y_true), - torch.cat(epoch_metadata)) + epoch_y_pred, + epoch_y_true, + epoch_metadata) results['epoch'] = epoch dataset['eval_logger'].log(results) @@ -142,7 +145,7 @@ def evaluate(algorithm, datasets, epoch, general_logger, config): # Skip saving train preds, since the train loader generally shuffles the data if split != 'train': - save_pred_if_needed(y_pred, dataset, epoch, config, is_best=False, force_save=True) + save_pred_if_needed(epoch_y_pred, dataset, epoch, config, is_best, force_save=True) def log_results(algorithm, dataset, general_logger, epoch, batch_idx): @@ -160,11 +163,11 @@ def save_pred_if_needed(y_pred, dataset, epoch, config, is_best, force_save=Fals if config.save_pred: prefix = get_pred_prefix(dataset, config) if force_save or (config.save_step is not None and (epoch + 1) % config.save_step == 0): - save_pred(y_pred, prefix + f'epoch:{epoch}_pred.csv') - if config.save_last: - save_pred(y_pred, prefix + f'epoch:last_pred.csv') + save_pred(y_pred, prefix + f'epoch:{epoch}_pred') + if (not force_save) and config.save_last: + save_pred(y_pred, prefix + f'epoch:last_pred') if config.save_best and is_best: - save_pred(y_pred, prefix + f'epoch:best_pred.csv') + save_pred(y_pred, prefix + f'epoch:best_pred') def save_model_if_needed(algorithm, dataset, epoch, config, is_best, best_val_metric): diff --git a/examples/transforms.py b/examples/transforms.py index bafbd42f..bbcd88a4 100644 --- a/examples/transforms.py +++ b/examples/transforms.py @@ -3,6 +3,10 @@ import torch def initialize_transform(transform_name, config, dataset): + """ + Transforms should take in a single (x, y) + and return (transformed_x, transformed_y). + """ if transform_name is None: return None elif transform_name=='bert': @@ -16,6 +20,11 @@ def initialize_transform(transform_name, config, dataset): else: raise ValueError(f"{transform_name} not recognized") +def transform_input_only(input_transform): + def transform(x, y): + return input_transform(x), y + return transform + def initialize_bert_transform(config): assert 'bert' in config.model assert config.max_token_length is not None @@ -41,7 +50,7 @@ def transform(text): dim=2) x = torch.squeeze(x, dim=0) # First shape dim is always 1 return x - return transform + return transform_input_only(transform) def getBertTokenizer(model): if model == 'bert-base-uncased': @@ -65,7 +74,7 @@ def initialize_image_base_transform(config, dataset): transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ] transform = transforms.Compose(transform_steps) - return transform + return transform_input_only(transform) def initialize_image_resize_and_center_crop_transform(config, dataset): """ @@ -84,7 +93,7 @@ def initialize_image_resize_and_center_crop_transform(config, dataset): transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) - return transform + return transform_input_only(transform) def initialize_poverty_train_transform(): transforms_ls = [ @@ -99,5 +108,6 @@ def transform_rgb(img): # bgr to rgb and back to bgr img[:3] = rgb_transform(img[:3][[2,1,0]])[[2,1,0]] return img + transform = transforms.Lambda(lambda x: transform_rgb(x)) - return transform + return transform_input_only(transform) diff --git a/examples/utils.py b/examples/utils.py index 89780d62..73fa1b12 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -176,9 +176,16 @@ def initialize_wandb(config): project=f"wilds") wandb.config.update(config) -def save_pred(y_pred, csv_path): - df = pd.DataFrame(y_pred.numpy()) - df.to_csv(csv_path, index=False, header=False) +def save_pred(y_pred, path_prefix): + # Single tensor + if torch.is_tensor(y_pred): + df = pd.DataFrame(y_pred.numpy()) + df.to_csv(path_prefix + '.csv', index=False, header=False) + # Dictionary + elif isinstance(y_pred, dict) or isinstance(y_pred, list): + torch.save(y_pred, path_prefix + '.pth') + else: + raise TypeError("Invalid type for save_pred") def get_replicate_str(dataset, config): if dataset['dataset'].dataset_name == 'poverty': @@ -203,3 +210,59 @@ def get_model_prefix(dataset, config): config.log_dir, f"{dataset_name}_{replicate_str}_") return prefix + +def move_to(obj, device): + if isinstance(obj, dict): + return {k: move_to(v, device) for k, v in obj.items()} + elif isinstance(obj, list): + return [move_to(v, device) for v in obj] + elif isinstance(obj, float) or isinstance(obj, int): + return obj + else: + # Assume obj is a Tensor or other type + # (like Batch, for MolPCBA) that supports .to(device) + return obj.to(device) + +def detach_and_clone(obj): + if torch.is_tensor(obj): + return obj.detach().clone() + elif isinstance(obj, dict): + return {k: detach_and_clone(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [detach_and_clone(v) for v in obj] + elif isinstance(obj, float) or isinstance(obj, int): + return obj + else: + raise TypeError("Invalid type for detach_and_clone") + +def collate_list(vec): + """ + If vec is a list of Tensors, it concatenates them all along the first dimension. + + If vec is a list of lists, it joins these lists together, but does not attempt to + recursively collate. This allows each element of the list to be, e.g., its own dict. + + If vec is a list of dicts (with the same keys in each dict), it returns a single dict + with the same keys. For each key, it recursively collates all entries in the list. + """ + if not isinstance(vec, list): + raise TypeError("collate_list must take in a list") + elem = vec[0] + if torch.is_tensor(elem): + return torch.cat(vec) + elif isinstance(elem, list): + return [obj for sublist in vec for obj in sublist] + elif isinstance(elem, dict): + return {k: collate_list([d[k] for d in vec]) for k in elem} + else: + raise TypeError("Elements of the list to collate must be tensors or dicts.") + +def remove_key(key): + """ + Returns a function that strips out a key from a dict. + """ + def remove(d): + if not isinstance(d, dict): + raise TypeError("remove_key must take in a dict") + return {k: v for (k,v) in d.items() if k != key} + return remove diff --git a/setup.py b/setup.py index 9cd1f596..1fb6a2e5 100644 --- a/setup.py +++ b/setup.py @@ -30,6 +30,7 @@ 'tqdm>=4.53.0', 'outdated>=0.2.0', 'pytz>=2020.4', + 'torchvision==0.8.2' ], license='MIT', packages=setuptools.find_packages(exclude=['dataset_preprocessing', 'examples', 'examples.models', 'examples.models.bert']), diff --git a/wilds/__init__.py b/wilds/__init__.py index 77f0ad5a..fe1bed0a 100644 --- a/wilds/__init__.py +++ b/wilds/__init__.py @@ -10,6 +10,7 @@ 'poverty', 'fmow', 'py150', + 'gwhd', ] additional_datasets = [ diff --git a/wilds/common/data_loaders.py b/wilds/common/data_loaders.py index b806832a..0cb23cee 100644 --- a/wilds/common/data_loaders.py +++ b/wilds/common/data_loaders.py @@ -4,20 +4,20 @@ from torch.utils.data.sampler import WeightedRandomSampler, SubsetRandomSampler from wilds.common.utils import get_counts, split_into_groups -def get_train_loader(loader, dataset, batch_size, +def get_train_loader(loader, dataset, batch_size, uniform_over_groups=None, grouper=None, distinct_groups=True, n_groups_per_batch=None, **loader_kwargs): """ Constructs and returns the data loader for training. Args: - loader (str): Loader type. 'standard' for standard loaders and 'group' for group loaders, - which first samples groups and then samples a fixed number of examples belonging + which first samples groups and then samples a fixed number of examples belonging to each group. - dataset (WILDSDataset or WILDSSubset): Data - batch_size (int): Batch size - - uniform_over_groups (None or bool): Whether to sample the groups uniformly or according to the + - uniform_over_groups (None or bool): Whether to sample the groups uniformly or according to the natural data distribution. - Setting to None applies the defaults for each type of loaders. - For standard loaders, the default is False. For group loaders, + Setting to None applies the defaults for each type of loaders. + For standard loaders, the default is False. For group loaders, the default is True. - grouper (Grouper): Grouper used for group loaders or for uniform_over_groups=True - distinct_groups (bool): Whether to sample distinct_groups within each minibatch for group loaders. @@ -82,7 +82,7 @@ def get_eval_loader(loader, dataset, batch_size, grouper=None, **loader_kwargs): """ Constructs and returns the data loader for evaluation. Args: - - loader (str): Loader type. 'standard' for standard loaders. + - loader (str): Loader type. 'standard' for standard loaders. - dataset (WILDSDataset or WILDSSubset): Data - batch_size (int): Batch size - loader_kwargs: kwargs passed into torch DataLoader initialization. @@ -104,7 +104,6 @@ class GroupSampler: then sampling data from those groups. It drops the last batch if it's incomplete. """ - def __init__(self, group_ids, batch_size, n_groups_per_batch, uniform_over_groups, distinct_groups): @@ -129,16 +128,13 @@ def __init__(self, group_ids, batch_size, n_groups_per_batch, self.group_prob = unique_counts.numpy() / unique_counts.numpy().sum() def __iter__(self): - for batch_id in range(self.num_batches): - # Note that we are selecting group indices rather than groups groups_for_batch = np.random.choice( len(self.unique_groups), size=self.n_groups_per_batch, replace=(not self.distinct_groups), p=self.group_prob) - sampled_ids = [ np.random.choice( self.group_indices[group], @@ -149,7 +145,6 @@ def __iter__(self): # Flatten sampled_ids = np.concatenate(sampled_ids) - yield sampled_ids def __len__(self): diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py index 0f5d7eb1..ec93cc9f 100644 --- a/wilds/common/metrics/all_metrics.py +++ b/wilds/common/metrics/all_metrics.py @@ -1,10 +1,13 @@ import torch import torch.nn as nn +from torchvision.ops.boxes import box_iou +from torchvision.models.detection._utils import Matcher +from torchvision.ops import nms, box_convert import numpy as np import torch.nn.functional as F from wilds.common.metrics.metric import Metric, ElementwiseMetric, MultiTaskMetric from wilds.common.metrics.loss import ElementwiseLoss -from wilds.common.utils import avg_over_groups, minimum, maximum +from wilds.common.utils import avg_over_groups, minimum, maximum, get_counts import sklearn.metrics from scipy.stats import pearsonr @@ -19,7 +22,7 @@ def binary_logits_to_score(logits): def multiclass_logits_to_pred(logits): """ - Takes multi-class logits of size (batch_size, ..., n_classes) and returns predictions + Takes multi-class logits of size (batch_size, ..., n_classes) and returns predictions by taking an argmax at the last dimension """ assert logits.dim() > 1 @@ -142,3 +145,78 @@ def _compute(self, y_pred, y_true): def worst(self, metrics): return minimum(metrics) + +class DummyMetric(Metric): + """ + For testing purposes. This Metric always returns -1. + """ + def __init__(self, prediction_fn=None, name=None): + self.prediction_fn = prediction_fn + if name is None: + name = 'dummy' + super().__init__(name=name) + + def _compute(self, y_pred, y_true): + return torch.tensor(-1) + + def _compute_group_wise(self, y_pred, y_true, g, n_groups): + group_metrics = torch.ones(n_groups, device=g.device) * -1 + group_counts = get_counts(g, n_groups) + worst_group_metric = self.worst(group_metrics) + return group_metrics, group_counts, worst_group_metric + + def worst(self, metrics): + return minimum(metrics) + +class DetectionAccuracy(ElementwiseMetric): + """ + Given a specific Intersection over union threshold, + determine the accuracy achieved for a one-class detector + """ + def __init__(self, iou_threshold=0.5, score_threshold=0.5, name=None): + self.iou_threshold = iou_threshold + self.score_threshold = score_threshold + if name is None: + name = "detection_acc" + super().__init__(name=name) + + def _compute_element_wise(self, y_pred, y_true): + batch_results = [] + for src_boxes, target in zip(y_true, y_pred): + target_boxes = target["boxes"] + target_scores = target["scores"] + + pred_boxes = target_boxes[target_scores > self.score_threshold] + det_accuracy = torch.mean(torch.stack([ self._accuracy(src_boxes["boxes"],pred_boxes,iou_thr) for iou_thr in np.arange(0.5,0.51,0.05)])) + batch_results.append(det_accuracy) + + return torch.tensor(batch_results) + + def _accuracy(self, src_boxes,pred_boxes , iou_threshold): + total_gt = len(src_boxes) + total_pred = len(pred_boxes) + if total_gt > 0 and total_pred > 0: + # Define the matcher and distance matrix based on iou + matcher = Matcher(iou_threshold,iou_threshold,allow_low_quality_matches=False) + match_quality_matrix = box_iou(src_boxes,pred_boxes) + results = matcher(match_quality_matrix) + true_positive = torch.count_nonzero(results.unique() != -1) + matched_elements = results[results > -1] + #in Matcher, a pred element can be matched only twice + false_positive = ( + torch.count_nonzero(results == -1) + + (len(matched_elements) - len(matched_elements.unique())) + ) + false_negative = total_gt - true_positive + acc = true_positive / ( true_positive + false_positive + false_negative ) + return true_positive / ( true_positive + false_positive + false_negative ) + elif total_gt == 0: + if total_pred > 0: + return torch.tensor(0.) + else: + return torch.tensor(1.) + elif total_gt > 0 and total_pred == 0: + return torch.tensor(0.) + + def worst(self, metrics): + return minimum(metrics) diff --git a/wilds/common/metrics/metric.py b/wilds/common/metrics/metric.py index 9c4372b0..89582577 100644 --- a/wilds/common/metrics/metric.py +++ b/wilds/common/metrics/metric.py @@ -1,5 +1,5 @@ import numpy as np -from wilds.common.utils import avg_over_groups, get_counts +from wilds.common.utils import avg_over_groups, get_counts, numel import torch class Metric: @@ -82,7 +82,7 @@ def compute(self, y_pred, y_true, return_dict=True): Output (return_dict=True): - results (dict): Dictionary of results, mapping metric.agg_metric_field to avg_metric """ - if y_true.numel()==0: + if numel(y_true) == 0: agg_metric = torch.tensor(0., device=y_true.device) else: agg_metric = self._compute(y_pred, y_true) @@ -133,6 +133,7 @@ def _compute_group_wise(self, y_pred, y_true, g, n_groups): self._compute( y_pred[g == group_idx], y_true[g == group_idx])) + group_metrics = torch.stack(group_metrics) worst_group_metric = self.worst(group_metrics[group_counts>0]) diff --git a/wilds/common/utils.py b/wilds/common/utils.py index 7854393a..ebf2f41d 100644 --- a/wilds/common/utils.py +++ b/wilds/common/utils.py @@ -81,8 +81,9 @@ def avg_over_groups(v, g, n_groups): group_avgs (Tensor): Vector of length num_groups group_counts (Tensor) """ + + assert v.device==g.device - device = v.device assert v.numel()==g.numel() group_count = get_counts(g, n_groups) group_avgs = torch_scatter.scatter(src=v, index=g, dim_size=n_groups, reduce='mean') @@ -113,7 +114,6 @@ def subsample_idxs(idxs, num=5000, take_rest=False, seed=None): idxs = idxs[:num] return idxs - def shuffle_arr(arr, seed=None): seed = (seed + 548207) if seed is not None else None rng = np.random.default_rng(seed) @@ -126,3 +126,11 @@ def threshold_at_recall(y_pred, y_true, global_recall=60): """ Calculate the model threshold to use to achieve a desired global_recall level. Assumes that y_true is a vector of the true binary labels.""" return np.percentile(y_pred[y_true == 1], 100-global_recall) + +def numel(obj): + if torch.is_tensor(obj): + return obj.numel() + elif isinstance(obj, list): + return len(obj) + else: + raise TypeError("Invalid type for numel") diff --git a/wilds/datasets/gwhd_dataset.py b/wilds/datasets/gwhd_dataset.py new file mode 100644 index 00000000..221b480a --- /dev/null +++ b/wilds/datasets/gwhd_dataset.py @@ -0,0 +1,146 @@ +import numpy as np +import pandas as pd +import torch +from pathlib import Path +from PIL import Image +from wilds.datasets.wilds_dataset import WILDSDataset +from wilds.common.grouper import CombinatorialGrouper +from wilds.common.metrics.all_metrics import DetectionAccuracy + + +def _collate_fn(batch): + """ + Stack x (batch[0]) and metadata (batch[2]), but not y. + originally, batch = (item1, item2, item3, item4) + after zip, batch = [(item1[0], item2[0], ..), ..] + """ + batch = list(zip(*batch)) + batch[0] = torch.stack(batch[0]) + batch[1] = list(batch[1]) + batch[2] = torch.stack(batch[2]) + return tuple(batch) + +class GWHDDataset(WILDSDataset): + """ + The GWHD-wilds wheat head localization dataset. + This is a modified version of the original Global Wheat Head Dataset. + This dataset is not part of the official WILDS benchmark. + We provide it for convenience and to reproduce observations discussed in the WILDS paper. + Supported `split_scheme`: + 'official' for WILDS related tasks. + To reproduce the baseline, several splits are needed: + - to train a model on train domains and test against a all test split: 'train_in-dist' + - "benchmark_biased" ; "benchmark_in-dist" + Input (x): + 1024x1024 RGB images of wheat field canopy between flowering and ripening. + Output (y): + y is a nx4-dimensional vector where each line represents a box coordinate (x_min,y_min,x_max,y_max) + Metadata: + Each image is annotated with the ID of the domain it came from (integer from 0 to 10). + Website: + http://www.global-wheat.com/ + Original publication: + @article{david_global_2020, + title = {Global {Wheat} {Head} {Detection} ({GWHD}) {Dataset}: {A} {Large} and {Diverse} {Dataset} of {High}-{Resolution} {RGB}-{Labelled} {Images} to {Develop} and {Benchmark} {Wheat} {Head} {Detection} {Methods}}, + volume = {2020}, + url = {https://doi.org/10.34133/2020/3521852}, + doi = {10.34133/2020/3521852}, + journal = {Plant Phenomics}, + author = {David, Etienne and Madec, Simon and Sadeghi-Tehran, Pouria and Aasen, Helge and Zheng, Bangyou and Liu, Shouyang and Kirchgessner, Norbert and Ishikawa, Goro and Nagasawa, Koichi and Badhon, Minhajul A. and Pozniak, Curtis and de Solan, Benoit and Hund, Andreas and Chapman, Scott C. and Baret, Frédéric and Stavness, Ian and Guo, Wei}, + month = aug, + year = {2020}, + note = {Publisher: AAAS}, + pages = {3521852}, + } + License: + This dataset is distributed under the MIT license. + https://github.com/snap-stanford/ogb/blob/master/LICENSE + """ + + _dataset_name = 'gwhd' + _versions_dict = { + '2.0': { + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x42fa9775eacc453489a428abd59a437d/contents/blob/', + 'compressed_size': None}} + + def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): + + self._version = version + self._data_dir = self.initialize_data_dir(root_dir, download) + self._original_resolution = (1024, 1024) + self.root = Path(self.data_dir) + self._is_detection = True + self._is_classification = False + self._y_size = None + self._n_classes = 1 + + self._split_scheme = split_scheme + + # Get filenames + + if split_scheme =="official": + train_data_df = pd.read_csv(self.root / f'official_train.csv') + val_data_df = pd.read_csv(self.root / f'official_val.csv') + test_data_df = pd.read_csv(self.root / f'official_test.csv') + + elif split_scheme == "benchmark_biased": + train_data_df = pd.read_csv(self.root / f'official_train.csv') + val_data_df = pd.read_csv(self.root / f'official_val.csv') + test_data_df = pd.read_csv(self.root / f'in-dist_test.csv') + + elif split_scheme == "benchmark_in-dist": + train_data_df = pd.read_csv(self.root / f'in-dist_train.csv') + val_data_df = pd.read_csv(self.root / f'official_val.csv') + test_data_df = pd.read_csv(self.root / f'in-dist_test.csv') + + + self._image_array = [] + self._split_array, self._y_array, self._metadata_array = [], [], [] + + for i, df in enumerate([train_data_df, val_data_df, test_data_df]): + self._image_array.extend(list(df['image'].values)) + labels = list(df['labels'].values) + self._split_array.extend([i] * len(labels)) + + labels = [{ + "boxes": torch.stack([ + torch.tensor([int(float(i)) for i in box.split(" ")]) + for box in boxes.split(";") + ]), + "labels": torch.tensor([1]*len(list(boxes.split(";")))).long() + } if type(boxes) != float else { + "boxes": torch.empty(0,4), + "labels": torch.empty(0,dtype=torch.long) + } for boxes in labels] + + self._y_array.extend(labels) + self._metadata_array.extend(list(df['group'].values)) + + self._split_array = np.array(self._split_array) + + self._metadata_array = torch.tensor(self._metadata_array, + dtype=torch.long).unsqueeze(1) + self._metadata_fields = ['location'] + + self._eval_grouper = CombinatorialGrouper( + dataset=self, + groupby_fields=['location']) + + self._metric = DetectionAccuracy() # TODO + self._collate = _collate_fn + + super().__init__(root_dir, download, split_scheme) + + def get_input(self, idx): + """ + Returns x for a given idx. + """ + img_filename = self.root / "images" / self._image_array[idx] + x = Image.open(img_filename) + return x + + def eval(self, y_pred, y_true, metadata): + return self.standard_group_eval( + self._metric, + self._eval_grouper, + y_pred, y_true, metadata) diff --git a/wilds/datasets/wilds_dataset.py b/wilds/datasets/wilds_dataset.py index 1f8bf21a..8812f957 100644 --- a/wilds/datasets/wilds_dataset.py +++ b/wilds/datasets/wilds_dataset.py @@ -95,8 +95,8 @@ def check_init(self): assert 'train' in self.split_dict assert 'val' in self.split_dict - # Check that required arrays are Tensors - assert isinstance(self.y_array, torch.Tensor), 'y_array must be a torch.Tensor' + # Check the form of the required arrays + assert (isinstance(self.y_array, torch.Tensor) or isinstance(self.y_array, list)) assert isinstance(self.metadata_array, torch.Tensor), 'metadata_array must be a torch.Tensor' # Check that dimensions match @@ -106,6 +106,10 @@ def check_init(self): # Check metadata assert len(self.metadata_array.shape) == 2 assert len(self.metadata_fields) == self.metadata_array.shape[1] + + # Check that it is not both classification and detection + assert not (self.is_classification and self.is_detection) + # For convenience, include y in metadata_fields if y_size == 1 if self.y_size == 1: assert 'y' in self.metadata_fields @@ -242,9 +246,15 @@ def n_classes(self): def is_classification(self): """ Boolean. True if the task is classification, and false otherwise. - Used for logging purposes. """ - return (self.n_classes is not None) + return getattr(self, '_is_classification', (self.n_classes is not None)) + + @property + def is_detection(self): + """ + Boolean. True if the task is detection, and false otherwise. + """ + return getattr(self, '_is_detection', False) @property def metadata_fields(self): @@ -443,7 +453,7 @@ def __init__(self, dataset, indices, transform): def __getitem__(self, idx): x, y, metadata = self.dataset[self.indices[idx]] if self.transform is not None: - x = self.transform(x) + x, y = self.transform(x, y) return x, y, metadata def __len__(self): diff --git a/wilds/get_dataset.py b/wilds/get_dataset.py index 1073100f..cfa5f2c7 100644 --- a/wilds/get_dataset.py +++ b/wilds/get_dataset.py @@ -55,7 +55,7 @@ def get_dataset(dataset, version=None, **dataset_kwargs): elif dataset == 'poverty': if version == '1.0': from wilds.datasets.archive.poverty_v1_0_dataset import PovertyMapDataset - else: + else: from wilds.datasets.poverty_dataset import PovertyMapDataset return PovertyMapDataset(version=version, **dataset_kwargs) @@ -77,3 +77,7 @@ def get_dataset(dataset, version=None, **dataset_kwargs): elif dataset == 'sqf': from wilds.datasets.sqf_dataset import SQFDataset return SQFDataset(version=version, **dataset_kwargs) + + elif dataset == 'gwhd': + from wilds.datasets.gwhd_dataset import GWHDDataset + return GWHDDataset(version=version, **dataset_kwargs) From 81da4eb14b95ffc3b9056eb7ae9ba40f2de13e47 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Sun, 9 May 2021 15:31:54 -0700 Subject: [PATCH 176/244] small cleanups --- README.md | 13 ++- examples/models/detection/fasterrcnn.py | 48 ++--------- wilds/__init__.py | 2 +- wilds/datasets/gwhd_dataset.py | 105 +++++++++++++----------- 4 files changed, 73 insertions(+), 95 deletions(-) diff --git a/README.md b/README.md index b5d879f2..ce05cae2 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,7 @@ pip install -e . - torch>=1.7.0 - torch-scatter>=2.0.5 - torch-geometric>=1.6.1 -- tqdm>=4.53.0 +- tqdm>=4.53.0 Running `pip install wilds` or `pip install -e .` will automatically check for and install all of these requirements except for the `torch-scatter` and `torch-geometric` packages, which require a [quick manual install](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html#installation-via-binaries). @@ -83,6 +83,7 @@ python examples/run_expt.py --dataset civilcomments --algorithm groupDRO --root_ The scripts are set up to facilitate general-purpose algorithm development: new algorithms can be added to `examples/algorithms` and then run on all of the WILDS datasets using the default models. +### Downloading and training on the WILDS datasets The first time you run these scripts, you might need to download the datasets. You can do so with the `--download` argument, for example: ``` python examples/run_expt.py --dataset civilcomments --algorithm groupDRO --root_dir data --download @@ -113,16 +114,20 @@ While the `camelyon17` dataset is small and fast to train on, we advise against The image datasets (`iwildcam`, `camelyon17`, `fmow`, and `poverty`) tend to have high disk I/O usage. If training time is much slower for you than the approximate times listed above, consider checking if I/O is a bottleneck (e.g., by moving to a local disk if you are using a network drive, or by increasing the number of data loader workers). To speed up training, you could also disable evaluation at each epoch or for all splits by toggling `--evaluate_all_splits` and related arguments. +### Evaluating trained models We also provide an evaluation script that aggregates prediction CSV files for different replicates and reports on their combined evaluation. To use this, run: ```bash python examples/evaluate.py --root-dir ``` -where `` is the path to your predictions directory, `` is where the results JSON will be -outputted and `` is the dataset directory. The predictions directory should have a subdirectory for each dataset -(e.g. `iwildcam`) containing prediction CSV files to evaluate; see our [submission guidelines](https://wilds.stanford.edu/submit/) for the format. The evaluation script will skip over any datasets that has missing prediction files. Any dataset not in `` will be downloaded to ``. +where `` is the path to your predictions directory, `` is where the results JSON will be writte, and `` is the dataset root directory. +The predictions directory should have a subdirectory for each dataset +(e.g. `iwildcam`) containing prediction CSV files to evaluate; see our [submission guidelines](https://wilds.stanford.edu/submit/) for the format. +The evaluation script will skip over any datasets that has missing prediction files. +Any dataset not in `` will be downloaded to ``. +### Reproducibility We have an [executable version](https://wilds.stanford.edu/codalab) of our paper on CodaLab that contains the exact commands, code, and data for the experiments reported in our paper, which rely on these scripts. Trained model weights for all datasets can also be found there. diff --git a/examples/models/detection/fasterrcnn.py b/examples/models/detection/fasterrcnn.py index 6abb6b16..59d736a1 100644 --- a/examples/models/detection/fasterrcnn.py +++ b/examples/models/detection/fasterrcnn.py @@ -1,13 +1,12 @@ """ -This module contains all the necessary modifications to adapt "Faster-RCNN" of the torchvision library -in order to be able to calculate the loss per image +This module adapts Faster-RCNN from the torchvision library to compute per-image losses, +instead of the default per-batch losses. +It is based on the version from torchvision==0.8.2, +and has not been tested on other versions. -It has been developped from torchvision=0.8.2 and did not has been tested on other versions - -All credits : +The torchvision library is distributed under the BSD 3-Clause License: https://github.com/pytorch/vision/blob/master/LICENSE https://github.com/pytorch/vision/tree/master/torchvision/models/detection - """ import torch @@ -20,30 +19,22 @@ from typing import Tuple, List, Dict, Optional, Union from torch import nn - +from torch.nn import functional as F import torchvision from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, FasterRCNN from torchvision.models.detection.backbone_utils import resnet_fpn_backbone from torchvision.models.utils import load_state_dict_from_url - - from torchvision.ops import misc as misc_nn_ops from torchvision.ops import MultiScaleRoIAlign - - +from torchvision.models.detection import _utils as det_utils from torchvision.models.detection.anchor_utils import AnchorGenerator from torchvision.models.detection.generalized_rcnn import GeneralizedRCNN from torchvision.models.detection.faster_rcnn import TwoMLPHead - from torchvision.models.detection.rpn import RPNHead, RegionProposalNetwork, concat_box_prediction_layers,permute_and_flatten from torchvision.models.detection.roi_heads import RoIHeads - -from torchvision.models.detection import _utils as det_utils -from torch.nn import functional as F from torchvision.models.detection.transform import GeneralizedRCNNTransform - model_urls = { 'fasterrcnn_resnet50_fpn_coco': 'https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth', @@ -53,7 +44,6 @@ 'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth' } - def batch_concat_box_prediction_layers(box_cls, box_regression): # type: (List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor] box_cls_flattened = [] @@ -133,7 +123,6 @@ def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets) for objectness_, regression_targets_,labels_,objectness_,pred_bbox_deltas_ in zip(objectness,regression_targets,labels,objectness,pred_bbox_deltas): sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(torch.unsqueeze(labels_,dim=0)) - sampled_pos_inds = torch.where(torch.cat(sampled_pos_inds, dim=0))[0] sampled_neg_inds = torch.where(torch.cat(sampled_neg_inds, dim=0))[0] sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0) @@ -145,7 +134,6 @@ def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets) size_average=False, ) / (sampled_inds.numel())) - objectness_loss.append(F.binary_cross_entropy_with_logits( objectness_[sampled_inds].flatten(), labels_[sampled_inds] )) @@ -221,10 +209,8 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets): classification_loss (Tensor) box_loss (Tensor) """ - class_logits = torch.split(class_logits, 512,dim=0) box_regression = torch.split(box_regression, 512,dim=0) - classification_loss = [] box_loss = [] @@ -291,13 +277,11 @@ def forward(self, matched_idxs = None box_features = self.box_roi_pool(features, proposals, image_shapes) - box_features = self.box_head(box_features) class_logits, box_regression = self.box_predictor(box_features) result = torch.jit.annotate(List[Dict[str, torch.Tensor]], []) losses = {} - if self.training: assert labels is not None and regression_targets is not None @@ -435,10 +419,9 @@ def __init__(self, backbone, num_classes=None, transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std) super(FastWILDS, self).__init__(backbone, rpn, roi_heads, transform) + # Set your own forward pass def forward(self, images, targets=None): - - if self.training: if targets is None: raise ValueError("In training mode, targets should be passed") @@ -460,7 +443,6 @@ def forward(self, images, targets=None): assert len(val) == 2 original_image_sizes.append((val[0], val[1])) - images, targets = self.transform(images, targets) # Check for degenerate boxes @@ -482,11 +464,7 @@ def forward(self, images, targets=None): features = OrderedDict([('0', features)]) proposals, proposal_losses = self.rpn(images, features, targets) - - detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets) - - detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes) for idx, det in enumerate(detections): @@ -496,28 +474,18 @@ def forward(self, images, targets=None): for k,v in detector_losses.items(): det["losses"][k] = v[idx] - return detections - - - - - class FasterRCNNLoss(nn.Module): def __init__(self,device): self.device = device super().__init__() def forward(self, outputs, targets): - - # loss values are loss_classifier loss_box_reg loss_objectness": loss_objectness, loss_rpn_box_reg try: elementwise_loss = torch.stack([sum(v for v in item["losses"].values()) for item in outputs]) except: elementwise_loss = torch.ones(len(outputs)).to(self.device) - - return elementwise_loss diff --git a/wilds/__init__.py b/wilds/__init__.py index fe1bed0a..2072baba 100644 --- a/wilds/__init__.py +++ b/wilds/__init__.py @@ -10,7 +10,6 @@ 'poverty', 'fmow', 'py150', - 'gwhd', ] additional_datasets = [ @@ -19,6 +18,7 @@ 'yelp', 'bdd100k', 'sqf', + 'gwhd', ] supported_datasets = benchmark_datasets + additional_datasets diff --git a/wilds/datasets/gwhd_dataset.py b/wilds/datasets/gwhd_dataset.py index 432154b3..194f16ae 100644 --- a/wilds/datasets/gwhd_dataset.py +++ b/wilds/datasets/gwhd_dataset.py @@ -7,46 +7,23 @@ from wilds.common.grouper import CombinatorialGrouper from wilds.common.metrics.all_metrics import DetectionAccuracy -def decode_string(BoxesString): - """ - Small method to decode the BoxesString - """ - if BoxesString == "no_box": - return np.zeros((0,4)) - else: - try: - boxes = np.array([np.array([int(i) for i in box.split(" ")]) - for box in BoxesString.split(";")]) - return boxes - except: - print(BoxesString) - print("Submission is not well formatted. empty boxes will be returned") - return np.zeros((0,4)) -def _collate_fn(batch): - """ - Stack x (batch[0]) and metadata (batch[2]), but not y. - originally, batch = (item1, item2, item3, item4) - after zip, batch = [(item1[0], item2[0], ..), ..] - """ - batch = list(zip(*batch)) - batch[0] = torch.stack(batch[0]) - batch[1] = list(batch[1]) - batch[2] = torch.stack(batch[2]) - return tuple(batch) - class GWHDDataset(WILDSDataset): """ - The GWHD-wilds wheat head localization dataset. + The GWHD-WILDS wheat head localization dataset. This is a modified version of the original Global Wheat Head Dataset 2021. - This dataset is not part of the official WILDS benchmark. - We provide it for convenience and to reproduce observations discussed in the WILDS paper. + + The current version does not contain test or validation labels, as it is being used in a + currently-running competition. + After the competition concludes in July 2021, we will update the dataset to contain the + final splits with test and validation labels, and add the dataset to the official WILDS + benchmark. + Supported `split_scheme`: - - 'official' for WILDS related tasks. - - 'in-dist' and 'ood_with_subsampled_test' to reproduce the baseline described in the paper. WARNING: these splits are not accessible before v1.0 + - 'official' Input (x): - 1024x1024 RGB images of wheat field canopy starting from anthesis (flowering) to ripening. + 1024 x 1024 RGB images of wheat field canopy starting from anthesis (flowering) to ripening. Output (y): - y is a nx4-dimensional vector where each line represents a box coordinate (x_min, y_min, x_max, y_max) + y is a n x 4-dimensional vector where each line represents a box coordinate (x_min, y_min, x_max, y_max) Metadata: Each image is annotated with the ID of the domain (location_date_sensor) it came from (integer from 0 to 46). Website: @@ -59,7 +36,7 @@ class GWHDDataset(WILDSDataset): doi = {10.34133/2020/3521852}, journal = {Plant Phenomics}, author = {David, Etienne and Madec, Simon and Sadeghi-Tehran, Pouria and Aasen, Helge and Zheng, Bangyou and Liu, Shouyang and Kirchgessner, Norbert and Ishikawa, Goro and Nagasawa, Koichi and Badhon, Minhajul A. and Pozniak, Curtis and de Solan, Benoit and Hund, Andreas and Chapman, Scott C. and Baret, Frédéric and Stavness, Ian and Guo, Wei}, - month = aug, + month = Aug, year = {2020}, note = {Publisher: AAAS}, pages = {3521852}, @@ -69,9 +46,14 @@ class GWHDDataset(WILDSDataset): """ _dataset_name = 'gwhd' - - # Version 0.9 corresponds to the final dataset, but without the test label. It can be used to train - # a model but no validation nor test metrics are available before 5th July 2021 + + # Version 0.9 corresponds to the final dataset, but without the validation and test labels, + # since it is being used in a currently-running competition (http://www.global-wheat.com/). + # Users can submit their val+test predictions to the competition to obtain an estimate of + # held-out performance computed on a fraction of those predictions; + # please see the tutorial at https://www.aicrowd.com/challenges/global-wheat-challenge-2021. + # We will update the dataset to include these labels and update the splits after the + # competition ends in July 2021. _versions_dict = { '0.9': { 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x8ba9122a41454997afdfb78762d390cf/contents/blob/', @@ -91,8 +73,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' self._split_scheme = split_scheme # Get filenames - - if split_scheme =="official": + if split_scheme == "official": train_data_df = pd.read_csv(self.root / f'official_train.csv') val_data_df = pd.read_csv(self.root / f'official_val.csv') test_data_df = pd.read_csv(self.root / f'official_test.csv') @@ -113,18 +94,15 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' val_data_df = pd.read_csv(self.root / f'official_val.csv') test_data_df = pd.read_csv(self.root / f'in-dist_test.csv') - self._image_array = [] self._split_array, self._y_array, self._metadata_array = [], [], [] for i, df in enumerate([train_data_df, val_data_df, test_data_df]): self._image_array.extend(list(df['image_name'].values)) boxes_string = list(df['BoxesString'].values) - all_boxes = [decode_string(box_string) for box_string in boxes_string] - - + all_boxes = [GWHDDataset._decode_string(box_string) for box_string in boxes_string] self._split_array.extend([i] * len(all_boxes)) - + labels = [{ "boxes": torch.stack([ torch.tensor(box) @@ -140,17 +118,14 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' self._metadata_array.extend(list(df['domain'].values)) self._split_array = np.array(self._split_array) - self._metadata_array = torch.tensor(self._metadata_array, dtype=torch.long).unsqueeze(1) self._metadata_fields = ['location_date_sensor'] - self._eval_grouper = CombinatorialGrouper( dataset=self, groupby_fields=['location_date_sensor']) - - self._metric = DetectionAccuracy() - self._collate = _collate_fn + self._metric = DetectionAccuracy() + self._collate = GWHDDataset._collate_fn super().__init__(root_dir, download, split_scheme) @@ -167,3 +142,33 @@ def eval(self, y_pred, y_true, metadata): self._metric, self._eval_grouper, y_pred, y_true, metadata) + + @staticmethod + def _decode_string(box_string): + """ + Helper method to decode each box_string in the BoxesString field of the data CSVs + """ + if boxes_string == "no_box": + return np.zeros((0,4)) + else: + try: + boxes = np.array([np.array([int(i) for i in box.split(" ")]) + for box in boxes_string.split(";")]) + return boxes + except: + print(boxes_string) + print("Submission is not well formatted. empty boxes will be returned") + return np.zeros((0,4)) + + @staticmethod + def _collate_fn(batch): + """ + Stack x (batch[0]) and metadata (batch[2]), but not y. + originally, batch = (item1, item2, item3, item4) + after zip, batch = [(item1[0], item2[0], ..), ..] + """ + batch = list(zip(*batch)) + batch[0] = torch.stack(batch[0]) + batch[1] = list(batch[1]) + batch[2] = torch.stack(batch[2]) + return tuple(batch) From 0d45610a2a6e5eee4e74cd4887e74d8e37fbe769 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Sun, 9 May 2021 17:04:01 -0700 Subject: [PATCH 177/244] box_string typo --- wilds/datasets/gwhd_dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/wilds/datasets/gwhd_dataset.py b/wilds/datasets/gwhd_dataset.py index 194f16ae..bb84e91d 100644 --- a/wilds/datasets/gwhd_dataset.py +++ b/wilds/datasets/gwhd_dataset.py @@ -57,7 +57,7 @@ class GWHDDataset(WILDSDataset): _versions_dict = { '0.9': { 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x8ba9122a41454997afdfb78762d390cf/contents/blob/', - 'compressed_size': None}} + 'compressed_size': 10_280_247_296}} def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): @@ -148,15 +148,15 @@ def _decode_string(box_string): """ Helper method to decode each box_string in the BoxesString field of the data CSVs """ - if boxes_string == "no_box": + if box_string == "no_box": return np.zeros((0,4)) else: try: boxes = np.array([np.array([int(i) for i in box.split(" ")]) - for box in boxes_string.split(";")]) + for box in box_string.split(";")]) return boxes except: - print(boxes_string) + print(box_string) print("Submission is not well formatted. empty boxes will be returned") return np.zeros((0,4)) From eb4cef8b532a1284c34f2aabe73253d53b5bf1d1 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Mon, 10 May 2021 23:46:44 -0700 Subject: [PATCH 178/244] new in-dist split --- wilds/datasets/rxrx1_dataset.py | 47 +++++++++++++++++++++++++++++---- 1 file changed, 42 insertions(+), 5 deletions(-) diff --git a/wilds/datasets/rxrx1_dataset.py b/wilds/datasets/rxrx1_dataset.py index 5ba7c0e3..7a9a5a59 100644 --- a/wilds/datasets/rxrx1_dataset.py +++ b/wilds/datasets/rxrx1_dataset.py @@ -1,5 +1,6 @@ import os from pathlib import Path +from collections import defaultdict from PIL import Image import pandas as pd @@ -72,16 +73,52 @@ def __init__(self, version=None, root_dir='data', download=False, # Splits if split_scheme == 'official': - self._split_dict = {'train': 0, 'val': 1, 'test': 2} - self._split_names = {'train': 'Train', 'val': 'Validation', 'test': 'Test'} - self._split_array = df.dataset.apply(self._split_dict.get).values - elif split_scheme == 'in-dist': + # Training: 33 experiments, 1 site per experiment (site 1) + # Validation: 4 experiments, 2 sites per experiment + # Test: 14 experiments, 2 sites per experiment self._split_dict = {'train': 0, 'val': 1, 'test': 2, 'id-test': 3} self._split_names = {'train': 'Train', 'val': 'Validation', 'test': 'Test', 'id-test': 'In-Distribution Test'} self._split_array = df.dataset.apply(self._split_dict.get).values # id-test set - mask = ((df.dataset == "train") & (df.site == 2)).values + mask = ((df.dataset == 'train') & (df.site == 2)).values self._split_array = np.where(mask, 3, self._split_array) + # TODO: Split in-dist test and val? + elif split_scheme == 'in-dist': + # Training: 33 experiments total, 1 site per experiment (site 1) + # = 19 experiments from the original training set (site 1) + # + 14 experiments from the original test set (site 1) + # Validation: same + # Test: 14 experiments from the original test set, 1 site per experiment (site 2) + self._split_dict = {'train': 0, 'val': 1, 'test': 2} + self._split_names = {'train': 'Train', 'val': 'Validation', 'test': 'Test'} + self._split_array = df.dataset.apply(self._split_dict.get).values + # Use half of the training set (site 1) and discard site 2 + mask_to_discard = ((df.dataset == 'train') & (df.site == 2)).values + self._split_array[mask_to_discard] = -1 + # Take all site 1 in the test set and move it to train + mask_to_move = ((df.dataset == 'test') & (df.site == 1)).values + self._split_array[mask_to_move] = self._split_dict['train'] + # For each of the test experiments, remove a train experiment of the same cell type + test_cell_type_counts = defaultdict(int) + test_experiments = df.loc[(df['dataset'] == 'test'), 'experiment'].unique() + for test_experiment in test_experiments: + test_cell_type = test_experiment.split('-')[0] + test_cell_type_counts[test_cell_type] += 1 + # Training experiments are numbered starting from 1 and left-padded with 0s + experiments_to_discard = [ + f'{cell_type}-{num:02}' + for cell_type, count in test_cell_type_counts.items() + for num in range(1, count+1)] + # Sanity check + train_experiments = df.loc[(df['dataset'] == 'train'), 'experiment'].unique() + for experiment in experiments_to_discard: + assert experiment in train_experiments + mask_to_discard = (df.experiment == experiment).values + self._split_array[mask_to_discard] = -1 + # import IPython + # IPython.embed() + else: + raise ValueError(f'Split scheme {self._split_scheme} not recognized') # Filenames def create_filepath(row): From 8506cdd9cb79b2d9ceda5c6accee5aef250b4279 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Wed, 12 May 2021 15:37:21 -0700 Subject: [PATCH 179/244] replace filtering with nansum instead of sum --- examples/configs/datasets.py | 6 +++--- wilds/datasets/encodetfbs_dataset.py | 20 +++++++++----------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 5168333f..d7e7858d 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -118,11 +118,11 @@ 'val_metric_decreasing': False, 'optimizer': 'Adam', 'scheduler': 'MultiStepLR', - 'scheduler_kwargs': {'milestones':[3,6], 'gamma': 0.1}, - 'batch_size': 128, + 'scheduler_kwargs': {'milestones':[24,36], 'gamma': 0.1}, # used to be 6, 9, with 12 epochs + 'batch_size': 256, 'lr': 1e-3, 'weight_decay': 1e-4, - 'n_epochs': 9, + 'n_epochs': 48, 'n_groups_per_batch': 2, 'algo_log_metric': 'multitask_binary_accuracy', 'irm_lambda': 100.0, diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 3cac8b48..9d938f12 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -138,8 +138,8 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' val_chroms = ['chr2', 'chr9', 'chr11'] test_chroms = ['chr1', 'chr8', 'chr21'] official_train_cts = { - 'MAX': ['H1-hESC', 'HCT116', 'HeLa-S3', 'K562', 'A549', 'GM12878'], - 'REST': ['H1-hESC', 'HeLa-S3', 'MCF-7', 'Panc1'], + 'MAX': ['H1-hESC', 'HCT116', 'HeLa-S3', 'K562', 'A549', 'GM12878'], + 'REST': ['H1-hESC', 'HeLa-S3', 'MCF-7', 'Panc1'], 'JUND': ['HCT116', 'HeLa-S3', 'K562', 'MCF-7'] } official_val_cts = { @@ -148,7 +148,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' official_test_cts = { 'MAX': ['liver'], 'REST': ['liver'], 'JUND': ['liver'] } - + # Set the TF in split_scheme by prefacing it with 'tf..' self._transcription_factor = 'MAX' if 'tf.' in split_scheme: @@ -156,11 +156,11 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' self._transcription_factor = tkns[1] split_scheme = '.'.join(tkns[2:]) self._split_scheme = split_scheme - + train_celltypes = official_train_cts[self._transcription_factor] val_celltype = official_val_cts[self._transcription_factor] test_celltype = official_test_cts[self._transcription_factor] - + if self._split_scheme == 'official': splits = { 'train': { @@ -319,12 +319,10 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' self._split_array[chrom_mask & celltype_mask] = self._split_dict[split] keep_mask = (self._split_array != -1) - # Remove all-zero sequences from training. - remove_allnegative = True - if remove_allnegative: - train_mask = (self._split_array == self._split_dict['train']) - allzeroes_mask = (self._y_array.sum(axis=1) == 0).numpy() - keep_mask = keep_mask & ~(train_mask & allzeroes_mask) + # Remove all-zero sequences from training. + train_mask = (self._split_array == self._split_dict['train']) + allzeroes_mask = (self._y_array.nansum(axis=1) == 0).numpy() + keep_mask = keep_mask & ~(train_mask & allzeroes_mask) # Subsample the testing and validation indices, to speed up evaluation. # For the OOD splits (val and test), we subsample by a factor of 3 From 2b87cfe085e0d84c03522d23f376faea2c58d31c Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Wed, 12 May 2021 17:00:12 -0700 Subject: [PATCH 180/244] rxrx1 cleanup --- wilds/datasets/camelyon17_dataset.py | 4 +-- wilds/datasets/rxrx1_dataset.py | 50 +++++++++++++++++++--------- 2 files changed, 36 insertions(+), 18 deletions(-) diff --git a/wilds/datasets/camelyon17_dataset.py b/wilds/datasets/camelyon17_dataset.py index 2efeaa41..ff6e6c63 100644 --- a/wilds/datasets/camelyon17_dataset.py +++ b/wilds/datasets/camelyon17_dataset.py @@ -9,7 +9,7 @@ class Camelyon17Dataset(WILDSDataset): """ - The CAMELYON17-wilds histopathology dataset. + The CAMELYON17-WILDS histopathology dataset. This is a modified version of the original CAMELYON17 dataset. Supported `split_scheme`: @@ -144,7 +144,7 @@ def eval(self, y_pred, y_true, metadata, prediction_fn=None): are predicted labels. - y_true (LongTensor): Ground-truth labels - metadata (Tensor): Metadata - - prediction_fn (function): A function that turns y_pred into predicted labels + - prediction_fn (function): A function that turns y_pred into predicted labels Output: - results (dictionary): Dictionary of evaluation metrics - results_str (str): String summarizing the evaluation metrics diff --git a/wilds/datasets/rxrx1_dataset.py b/wilds/datasets/rxrx1_dataset.py index 7a9a5a59..5f8e9aeb 100644 --- a/wilds/datasets/rxrx1_dataset.py +++ b/wilds/datasets/rxrx1_dataset.py @@ -47,8 +47,7 @@ class RxRx1Dataset(WILDSDataset): This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License. To view a copy of this license, visit - http://creativecommons.org/licenses/by-nc-sa/4.0/ or send a letter to - Creative Commons, PO Box 1866, Mountain View, CA 94042, USA. + http://creativecommons.org/licenses/by-nc-sa/4.0/. """ _dataset_name = 'rxrx1' _versions_dict = { @@ -75,22 +74,43 @@ def __init__(self, version=None, root_dir='data', download=False, if split_scheme == 'official': # Training: 33 experiments, 1 site per experiment (site 1) # Validation: 4 experiments, 2 sites per experiment - # Test: 14 experiments, 2 sites per experiment - self._split_dict = {'train': 0, 'val': 1, 'test': 2, 'id-test': 3} - self._split_names = {'train': 'Train', 'val': 'Validation', 'test': 'Test', 'id-test': 'In-Distribution Test'} + # Test OOD: 14 experiments, 2 sites per experiment + # Test ID: Same 33 experiments from training set + # 1 site per experiment (site 2) + self._split_dict = { + 'train': 0, + 'val': 1, + 'test': 2, + 'id_test': 3 + } + self._split_names = { + 'train': 'Train', + 'val': 'Validation (OOD)', + 'test': 'Test (OOD)', + 'id-test': 'Test (ID)' + } self._split_array = df.dataset.apply(self._split_dict.get).values - # id-test set + # id_test set mask = ((df.dataset == 'train') & (df.site == 2)).values - self._split_array = np.where(mask, 3, self._split_array) - # TODO: Split in-dist test and val? + self._split_array[mask] = self.split_dict['id_test'] + elif split_scheme == 'in-dist': # Training: 33 experiments total, 1 site per experiment (site 1) - # = 19 experiments from the original training set (site 1) - # + 14 experiments from the original test set (site 1) - # Validation: same - # Test: 14 experiments from the original test set, 1 site per experiment (site 2) - self._split_dict = {'train': 0, 'val': 1, 'test': 2} - self._split_names = {'train': 'Train', 'val': 'Validation', 'test': 'Test'} + # = 19 experiments from the orig training set (site 1) + # + 14 experiments from the orig test set (site 1) + # Validation: same as official split + # Test: 14 experiments from the orig test set, + # 1 site per experiment (site 2) + self._split_dict = { + 'train': 0, + 'val': 1, + 'test': 2 + } + self._split_names = { + 'train': 'Train', + 'val': 'Validation', + 'test': 'Test' + } self._split_array = df.dataset.apply(self._split_dict.get).values # Use half of the training set (site 1) and discard site 2 mask_to_discard = ((df.dataset == 'train') & (df.site == 2)).values @@ -115,8 +135,6 @@ def __init__(self, version=None, root_dir='data', download=False, assert experiment in train_experiments mask_to_discard = (df.experiment == experiment).values self._split_array[mask_to_discard] = -1 - # import IPython - # IPython.embed() else: raise ValueError(f'Split scheme {self._split_scheme} not recognized') From 41bcbc9faed059b4b04c5a11e3fcb057591bcace Mon Sep 17 00:00:00 2001 From: aikanor Date: Wed, 12 May 2021 17:06:41 -0700 Subject: [PATCH 181/244] Fix nansum bug in training data filtering --- examples/models/CNN_genome.py | 4 ++-- wilds/datasets/encodetfbs_dataset.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py index 6d3f7d0d..4e8d70a3 100644 --- a/examples/models/CNN_genome.py +++ b/examples/models/CNN_genome.py @@ -5,7 +5,7 @@ import torch.nn.functional as F -def single_conv(in_channels, out_channels, kernel_size=7): +def single_conv(in_channels, out_channels, kernel_size=25): padding_size = int((kernel_size-1)/2) return nn.Sequential( nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding_size), @@ -13,7 +13,7 @@ def single_conv(in_channels, out_channels, kernel_size=7): nn.ReLU(inplace=True) ) -def double_conv(in_channels, out_channels, kernel_size=7): +def double_conv(in_channels, out_channels, kernel_size=25): padding_size = int((kernel_size-1)/2) return nn.Sequential( nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding_size), diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 3cac8b48..4ec06033 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -323,7 +323,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' remove_allnegative = True if remove_allnegative: train_mask = (self._split_array == self._split_dict['train']) - allzeroes_mask = (self._y_array.sum(axis=1) == 0).numpy() + allzeroes_mask = (self._y_array.nansum(axis=1) == 0).numpy() keep_mask = keep_mask & ~(train_mask & allzeroes_mask) # Subsample the testing and validation indices, to speed up evaluation. From a6f11ae1698f0e06b5d4b5f625dfb08a2efbf096 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Wed, 12 May 2021 21:38:10 -0700 Subject: [PATCH 182/244] docstring --- wilds/datasets/rxrx1_dataset.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/wilds/datasets/rxrx1_dataset.py b/wilds/datasets/rxrx1_dataset.py index 5f8e9aeb..232cb0b6 100644 --- a/wilds/datasets/rxrx1_dataset.py +++ b/wilds/datasets/rxrx1_dataset.py @@ -14,9 +14,12 @@ class RxRx1Dataset(WILDSDataset): """ - The RxRx1 Dataset. + The RxRx1-WILDS dataset. This is a modified version of the original RxRx1 dataset. + Supported `split_scheme`: + 'official' or 'in-dist' + Input (x): 3-channel fluorescent microscopy images of cells @@ -208,9 +211,7 @@ def get_input(self, idx): Output: - x (Tensor): Input features of the idx-th data point """ - # All images are in the train folder img_path = self.data_dir / self._input_array[idx] img = Image.open(img_path) - return img From 2f661a991ecd85abeffa3a17d1345fd4e63b8ab6 Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Thu, 13 May 2021 22:09:37 -0700 Subject: [PATCH 183/244] Fix split names --- wilds/datasets/rxrx1_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wilds/datasets/rxrx1_dataset.py b/wilds/datasets/rxrx1_dataset.py index 232cb0b6..9d93e3f5 100644 --- a/wilds/datasets/rxrx1_dataset.py +++ b/wilds/datasets/rxrx1_dataset.py @@ -90,7 +90,7 @@ def __init__(self, version=None, root_dir='data', download=False, 'train': 'Train', 'val': 'Validation (OOD)', 'test': 'Test (OOD)', - 'id-test': 'Test (ID)' + 'id_test': 'Test (ID)' } self._split_array = df.dataset.apply(self._split_dict.get).values # id_test set From 55107a75283d436e5a1652da7891a841d9c37c6e Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Fri, 14 May 2021 09:36:20 -0700 Subject: [PATCH 184/244] fix rxrx1 transform --- examples/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/transforms.py b/examples/transforms.py index 7c612bdf..dec68c2d 100644 --- a/examples/transforms.py +++ b/examples/transforms.py @@ -148,4 +148,4 @@ def random_d8(x: torch.Tensor) -> torch.Tensor: t_standardize, ] transform = transforms.Compose(transforms_ls) - return transform + return transform_input_only(transform) From 5e7f90553dc71692c9d8cda707937342cb6f5880 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Fri, 14 May 2021 16:42:02 -0700 Subject: [PATCH 185/244] test_id subsampling --- examples/configs/datasets.py | 6 ++++-- wilds/datasets/encodetfbs_dataset.py | 13 ++++++++----- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index d7e7858d..671219a7 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -118,11 +118,13 @@ 'val_metric_decreasing': False, 'optimizer': 'Adam', 'scheduler': 'MultiStepLR', - 'scheduler_kwargs': {'milestones':[24,36], 'gamma': 0.1}, # used to be 6, 9, with 12 epochs + 'scheduler_kwargs': {'milestones':[6,9], 'gamma': 0.1}, # used to be 6, 9, with 12 epochs + # 'scheduler': 'linear_schedule_with_warmup', + # 'scheduler_kwargs': {'num_warmup_steps': 800}, # about 160 minibatches per epoch 'batch_size': 256, 'lr': 1e-3, 'weight_decay': 1e-4, - 'n_epochs': 48, + 'n_epochs': 12, 'n_groups_per_batch': 2, 'algo_log_metric': 'multitask_binary_accuracy', 'irm_lambda': 100.0, diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 90a40da3..c3b7bc09 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -319,7 +319,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' self._split_array[chrom_mask & celltype_mask] = self._split_dict[split] keep_mask = (self._split_array != -1) - + # Remove all-zero sequences from training. train_mask = (self._split_array == self._split_dict['train']) allzeroes_mask = (self._y_array.nansum(axis=1) == 0).numpy() @@ -327,9 +327,12 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' # Subsample the testing and validation indices, to speed up evaluation. # For the OOD splits (val and test), we subsample by a factor of 3 - # For the id_val split if it exists, we subsample by a factor of 3*(# of training celltypes) - for subsample_seed, (split, subsample_factor) in enumerate( - [('val', 3), ('test', 3), ('id_val', 3*len(splits['train']['celltypes'])) ]): + # For the id_val and id_test splits, we subsample by a factor of 3*(# of training celltypes) + for subsample_seed, (split, subsample_factor) in enumerate([ + ('val', 3), + ('test', 3), + ('id_val', 3*len(splits['train']['celltypes'])), + ('id_test', 3*len(splits['train']['celltypes']))]): if split not in self._split_dict: continue split_mask = (self._split_array == self._split_dict[split]) split_idxs = np.arange(len(self._split_array))[split_mask] @@ -342,7 +345,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' self._metadata_df = self._metadata_df[keep_mask] self._split_array = self._split_array[keep_mask] - self._y_array = self._y_array[keep_mask] + self._y_array = self._y_array[keep_mask] self._all_chroms = sorted(list({chrom for _, d in splits.items() for chrom in d['chroms']})) self._all_celltypes = sorted(list({chrom for _, d in splits.items() for chrom in d['celltypes']})) From 7f0a6818e0df8801b074aedad8b1e9d219886798 Mon Sep 17 00:00:00 2001 From: aikanor Date: Sun, 16 May 2021 01:58:49 -0700 Subject: [PATCH 186/244] small kernel by default --- examples/models/CNN_genome.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/models/CNN_genome.py b/examples/models/CNN_genome.py index 4e8d70a3..6d3f7d0d 100644 --- a/examples/models/CNN_genome.py +++ b/examples/models/CNN_genome.py @@ -5,7 +5,7 @@ import torch.nn.functional as F -def single_conv(in_channels, out_channels, kernel_size=25): +def single_conv(in_channels, out_channels, kernel_size=7): padding_size = int((kernel_size-1)/2) return nn.Sequential( nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding_size), @@ -13,7 +13,7 @@ def single_conv(in_channels, out_channels, kernel_size=25): nn.ReLU(inplace=True) ) -def double_conv(in_channels, out_channels, kernel_size=25): +def double_conv(in_channels, out_channels, kernel_size=7): padding_size = int((kernel_size-1)/2) return nn.Sequential( nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding_size), From 4d63cc28fccd7fec5310d62768906f4e30bf030b Mon Sep 17 00:00:00 2001 From: Etienne David Date: Sun, 16 May 2021 20:55:49 +0200 Subject: [PATCH 187/244] correction for fasterrcnn by changin the default parameter in examples/configs/model --- examples/configs/datasets.py | 1 + examples/configs/model.py | 2 ++ examples/models/initializer.py | 8 +++++++- 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 47030ebb..1f91b85d 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -307,6 +307,7 @@ 'pin_memory': True, }, 'process_outputs_function': None, + } } diff --git a/examples/configs/model.py b/examples/configs/model.py index 582e2d6c..ee2098a8 100644 --- a/examples/configs/model.py +++ b/examples/configs/model.py @@ -41,6 +41,8 @@ 'model_kwargs': { 'pretrained_model': True, 'pretrained_backbone': True, + 'min_size' :1024, + 'max_size' :1024 } } } diff --git a/examples/models/initializer.py b/examples/models/initializer.py index b9a9bb7e..16bbe69a 100644 --- a/examples/models/initializer.py +++ b/examples/models/initializer.py @@ -162,6 +162,12 @@ def initialize_fasterrcnn_model(config, d_out): from models.detection.fasterrcnn import fasterrcnn_resnet50_fpn # load a model pre-trained pre-trained on COCO - model = fasterrcnn_resnet50_fpn(pretrained=config.model_kwargs["pretrained_model"],pretrained_backbone=config.model_kwargs["pretrained_backbone"],num_classes=d_out) + model = fasterrcnn_resnet50_fpn( + pretrained=config.model_kwargs["pretrained_model"], + pretrained_backbone=config.model_kwargs["pretrained_backbone"], + num_classes=d_out, + min_size=config.model_kwargs["min_size"], + max_size=config.model_kwargs["max_size"] + ) return model From b367f764ce9251a46284c6ec00cdbcb13306a244 Mon Sep 17 00:00:00 2001 From: kohpangwei Date: Wed, 19 May 2021 17:21:06 -0700 Subject: [PATCH 188/244] Remove path --- examples/run_expt.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/run_expt.py b/examples/run_expt.py index 19f6d570..ff0ba47d 100644 --- a/examples/run_expt.py +++ b/examples/run_expt.py @@ -8,7 +8,6 @@ import sys from collections import defaultdict -sys.path.insert(1, os.path.join(sys.path[0], '..')) import wilds from wilds.common.data_loaders import get_train_loader, get_eval_loader from wilds.common.grouper import CombinatorialGrouper From 2dd6fd6e9f00992b5ea19bbf16d93a2f4f010886 Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Fri, 21 May 2021 00:05:20 -0700 Subject: [PATCH 189/244] update default hyperparameters for rxrx1 --- examples/configs/datasets.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index c07fa224..9152ebdc 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -286,9 +286,12 @@ 'optimizer_kwargs': {}, 'scheduler': 'cosine_schedule_with_warmup', 'scheduler_kwargs': {'num_warmup_steps': 5415}, - 'batch_size': 75, - 'lr': 1e-4, + 'batch_size': 72, + 'lr': 1e-3, 'weight_decay': 1e-5, + 'n_groups_per_batch': 9, + 'coral_penalty_weight': 0.1, + 'irm_lambda': 1.0, 'n_epochs': 90, 'process_outputs_function': 'multiclass_logits_to_pred', }, From abcb48e65fdc40e6e5271fb3372e4383cd31fa31 Mon Sep 17 00:00:00 2001 From: Berton Earnshaw Date: Tue, 25 May 2021 10:10:16 -0600 Subject: [PATCH 190/244] Add rxrx1 reference --- wilds/datasets/rxrx1_dataset.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/wilds/datasets/rxrx1_dataset.py b/wilds/datasets/rxrx1_dataset.py index 9d93e3f5..bc728dce 100644 --- a/wilds/datasets/rxrx1_dataset.py +++ b/wilds/datasets/rxrx1_dataset.py @@ -37,13 +37,14 @@ class RxRx1Dataset(WILDSDataset): https://www.rxrx.ai/rxrx1 https://www.kaggle.com/c/recursion-cellular-image-classification - FIXME Original publication: - @article{, - title={}, - author={}, - journal={}, - year={} + @inproceedings{taylor2019rxrx1, + author = {Taylor, J. and Earnshaw, B. and Mabey, B. and Victors, M. and Yosinski, J.}, + title = {RxRx1: An Image Set for Cellular Morphological Variation Across Many Experimental Batches.}, + year = {2019}, + booktitle = {International Conference on Learning Representations (ICLR)}, + booksubtitle = {AI for Social Good Workshop}, + url = {https://aiforsocialgood.github.io/iclr2019/accepted/track1/pdfs/30_aisg_iclr2019.pdf}, } License: From 9ab1f3495592d6c533fecc747a32701feb46ad67 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Tue, 25 May 2021 22:51:04 -0700 Subject: [PATCH 191/244] globalwheat name change + changed default metric to dom avg --- examples/configs/datasets.py | 8 +- examples/run_expt.py | 5 +- examples/train.py | 1 + wilds/__init__.py | 2 +- wilds/common/utils.py | 5 +- wilds/datasets/gwhd_dataset.py | 174 --------------------------------- wilds/get_dataset.py | 6 +- 7 files changed, 14 insertions(+), 187 deletions(-) delete mode 100644 wilds/datasets/gwhd_dataset.py diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 1f91b85d..ede92f97 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -282,7 +282,7 @@ 'n_epochs': 4, 'process_outputs_function': None, }, - 'gwhd': { + 'globalwheat': { 'split_scheme': 'official', 'model': 'fasterrcnn', 'train_transform': 'image_base', @@ -292,7 +292,7 @@ 'pretrained': True}, 'loss_function': 'fasterrcnn_criterion', 'groupby_fields': ['location_date_sensor'], - 'val_metric': 'detection_acc_avg', # TODO + 'val_metric': 'detection_acc_avg_dom', 'val_metric_decreasing': False, 'algo_log_metric': None, # TODO 'optimizer': 'Adam', @@ -306,8 +306,8 @@ 'num_workers': 1, 'pin_memory': True, }, - 'process_outputs_function': None, - + 'process_outputs_function': None, + } } diff --git a/examples/run_expt.py b/examples/run_expt.py index ff0ba47d..8d8343b7 100644 --- a/examples/run_expt.py +++ b/examples/run_expt.py @@ -117,9 +117,10 @@ def main(): config = parser.parse_args() config = populate_defaults(config) - # For the GWHD dataset, we need to change the multiprocessing strategy or there will be + # For the GlobalWheat detection dataset, + # we need to change the multiprocessing strategy or there will be # too many open file descriptors - if config.dataset == 'gwhd': + if config.dataset == 'globalwheat': torch.multiprocessing.set_sharing_strategy('file_system') # Set device diff --git a/examples/train.py b/examples/train.py index 93cc076e..37184b61 100644 --- a/examples/train.py +++ b/examples/train.py @@ -5,6 +5,7 @@ from configs.supported import process_outputs_functions def run_epoch(algorithm, dataset, general_logger, epoch, config, train): + if dataset['verbose']: general_logger.write(f"\n{dataset['name']}:\n") diff --git a/wilds/__init__.py b/wilds/__init__.py index 2072baba..6d406f2a 100644 --- a/wilds/__init__.py +++ b/wilds/__init__.py @@ -18,7 +18,7 @@ 'yelp', 'bdd100k', 'sqf', - 'gwhd', + 'globalwheat', ] supported_datasets = benchmark_datasets + additional_datasets diff --git a/wilds/common/utils.py b/wilds/common/utils.py index ebf2f41d..b8560f6b 100644 --- a/wilds/common/utils.py +++ b/wilds/common/utils.py @@ -1,4 +1,4 @@ -import torch, torch_scatter +import torch import numpy as np from torch.utils.data import Subset from pandas.api.types import CategoricalDtype @@ -81,8 +81,7 @@ def avg_over_groups(v, g, n_groups): group_avgs (Tensor): Vector of length num_groups group_counts (Tensor) """ - - + import torch_scatter assert v.device==g.device assert v.numel()==g.numel() group_count = get_counts(g, n_groups) diff --git a/wilds/datasets/gwhd_dataset.py b/wilds/datasets/gwhd_dataset.py deleted file mode 100644 index bb84e91d..00000000 --- a/wilds/datasets/gwhd_dataset.py +++ /dev/null @@ -1,174 +0,0 @@ -import numpy as np -import pandas as pd -import torch -from pathlib import Path -from PIL import Image -from wilds.datasets.wilds_dataset import WILDSDataset -from wilds.common.grouper import CombinatorialGrouper -from wilds.common.metrics.all_metrics import DetectionAccuracy - -class GWHDDataset(WILDSDataset): - """ - The GWHD-WILDS wheat head localization dataset. - This is a modified version of the original Global Wheat Head Dataset 2021. - - The current version does not contain test or validation labels, as it is being used in a - currently-running competition. - After the competition concludes in July 2021, we will update the dataset to contain the - final splits with test and validation labels, and add the dataset to the official WILDS - benchmark. - - Supported `split_scheme`: - - 'official' - Input (x): - 1024 x 1024 RGB images of wheat field canopy starting from anthesis (flowering) to ripening. - Output (y): - y is a n x 4-dimensional vector where each line represents a box coordinate (x_min, y_min, x_max, y_max) - Metadata: - Each image is annotated with the ID of the domain (location_date_sensor) it came from (integer from 0 to 46). - Website: - http://www.global-wheat.com/ - Original publication: - @article{david_global_2020, - title = {Global {Wheat} {Head} {Detection} ({GWHD}) {Dataset}: {A} {Large} and {Diverse} {Dataset} of {High}-{Resolution} {RGB}-{Labelled} {Images} to {Develop} and {Benchmark} {Wheat} {Head} {Detection} {Methods}}, - volume = {2020}, - url = {https://doi.org/10.34133/2020/3521852}, - doi = {10.34133/2020/3521852}, - journal = {Plant Phenomics}, - author = {David, Etienne and Madec, Simon and Sadeghi-Tehran, Pouria and Aasen, Helge and Zheng, Bangyou and Liu, Shouyang and Kirchgessner, Norbert and Ishikawa, Goro and Nagasawa, Koichi and Badhon, Minhajul A. and Pozniak, Curtis and de Solan, Benoit and Hund, Andreas and Chapman, Scott C. and Baret, Frédéric and Stavness, Ian and Guo, Wei}, - month = Aug, - year = {2020}, - note = {Publisher: AAAS}, - pages = {3521852}, - } - License: - This dataset is distributed under the MIT license. - """ - - _dataset_name = 'gwhd' - - # Version 0.9 corresponds to the final dataset, but without the validation and test labels, - # since it is being used in a currently-running competition (http://www.global-wheat.com/). - # Users can submit their val+test predictions to the competition to obtain an estimate of - # held-out performance computed on a fraction of those predictions; - # please see the tutorial at https://www.aicrowd.com/challenges/global-wheat-challenge-2021. - # We will update the dataset to include these labels and update the splits after the - # competition ends in July 2021. - _versions_dict = { - '0.9': { - 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x8ba9122a41454997afdfb78762d390cf/contents/blob/', - 'compressed_size': 10_280_247_296}} - - def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): - - self._version = version - self._data_dir = self.initialize_data_dir(root_dir, download) - self._original_resolution = (1024, 1024) - self.root = Path(self.data_dir) - self._is_detection = True - self._is_classification = False - self._y_size = None - self._n_classes = 1 - - self._split_scheme = split_scheme - - # Get filenames - if split_scheme == "official": - train_data_df = pd.read_csv(self.root / f'official_train.csv') - val_data_df = pd.read_csv(self.root / f'official_val.csv') - test_data_df = pd.read_csv(self.root / f'official_test.csv') - - elif split_scheme == "ood_with_subsampled_test": - if version == "0.9": - print("Warning: ood_with_subsampled_test is not available in 0.9") - else: - train_data_df = pd.read_csv(self.root / f'official_train.csv') - val_data_df = pd.read_csv(self.root / f'official_val.csv') - test_data_df = pd.read_csv(self.root / f'in-dist_test.csv') - - elif split_scheme == "in-dist": - if version == "0.9": - print("Warning: ood_with_subsampled_test is not available in 0.9") - else: - train_data_df = pd.read_csv(self.root / f'in-dist_train.csv') - val_data_df = pd.read_csv(self.root / f'official_val.csv') - test_data_df = pd.read_csv(self.root / f'in-dist_test.csv') - - self._image_array = [] - self._split_array, self._y_array, self._metadata_array = [], [], [] - - for i, df in enumerate([train_data_df, val_data_df, test_data_df]): - self._image_array.extend(list(df['image_name'].values)) - boxes_string = list(df['BoxesString'].values) - all_boxes = [GWHDDataset._decode_string(box_string) for box_string in boxes_string] - self._split_array.extend([i] * len(all_boxes)) - - labels = [{ - "boxes": torch.stack([ - torch.tensor(box) - for box in boxes - ]), - "labels": torch.tensor([1]*len(boxes)).long() - } if len(boxes) > 0 else { - "boxes": torch.empty(0,4), - "labels": torch.empty(0,dtype=torch.long) - } for boxes in all_boxes] - - self._y_array.extend(labels) - self._metadata_array.extend(list(df['domain'].values)) - - self._split_array = np.array(self._split_array) - self._metadata_array = torch.tensor(self._metadata_array, - dtype=torch.long).unsqueeze(1) - self._metadata_fields = ['location_date_sensor'] - self._eval_grouper = CombinatorialGrouper( - dataset=self, - groupby_fields=['location_date_sensor']) - self._metric = DetectionAccuracy() - self._collate = GWHDDataset._collate_fn - - super().__init__(root_dir, download, split_scheme) - - def get_input(self, idx): - """ - Returns x for a given idx. - """ - img_filename = self.root / "images" / self._image_array[idx] - x = Image.open(img_filename) - return x - - def eval(self, y_pred, y_true, metadata): - return self.standard_group_eval( - self._metric, - self._eval_grouper, - y_pred, y_true, metadata) - - @staticmethod - def _decode_string(box_string): - """ - Helper method to decode each box_string in the BoxesString field of the data CSVs - """ - if box_string == "no_box": - return np.zeros((0,4)) - else: - try: - boxes = np.array([np.array([int(i) for i in box.split(" ")]) - for box in box_string.split(";")]) - return boxes - except: - print(box_string) - print("Submission is not well formatted. empty boxes will be returned") - return np.zeros((0,4)) - - @staticmethod - def _collate_fn(batch): - """ - Stack x (batch[0]) and metadata (batch[2]), but not y. - originally, batch = (item1, item2, item3, item4) - after zip, batch = [(item1[0], item2[0], ..), ..] - """ - batch = list(zip(*batch)) - batch[0] = torch.stack(batch[0]) - batch[1] = list(batch[1]) - batch[2] = torch.stack(batch[2]) - return tuple(batch) diff --git a/wilds/get_dataset.py b/wilds/get_dataset.py index cfa5f2c7..55c69c28 100644 --- a/wilds/get_dataset.py +++ b/wilds/get_dataset.py @@ -78,6 +78,6 @@ def get_dataset(dataset, version=None, **dataset_kwargs): from wilds.datasets.sqf_dataset import SQFDataset return SQFDataset(version=version, **dataset_kwargs) - elif dataset == 'gwhd': - from wilds.datasets.gwhd_dataset import GWHDDataset - return GWHDDataset(version=version, **dataset_kwargs) + elif dataset == 'globalwheat': + from wilds.datasets.globalwheat_dataset import GlobalWheatDataset + return GlobalWheatDataset(version=version, **dataset_kwargs) From fa229e9395ffec715ff2a995f792df8c1984216a Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Tue, 25 May 2021 22:52:00 -0700 Subject: [PATCH 192/244] add globalwheat_dataset --- wilds/datasets/globalwheat_dataset.py | 191 ++++++++++++++++++++++++++ 1 file changed, 191 insertions(+) create mode 100644 wilds/datasets/globalwheat_dataset.py diff --git a/wilds/datasets/globalwheat_dataset.py b/wilds/datasets/globalwheat_dataset.py new file mode 100644 index 00000000..d4683dc3 --- /dev/null +++ b/wilds/datasets/globalwheat_dataset.py @@ -0,0 +1,191 @@ +import numpy as np +import pandas as pd +import torch +from pathlib import Path +from PIL import Image +from wilds.datasets.wilds_dataset import WILDSDataset +from wilds.common.grouper import CombinatorialGrouper +from wilds.common.metrics.all_metrics import DetectionAccuracy + +class GlobalWheatDataset(WILDSDataset): + """ + The GlobalWheat-WILDS wheat head localization dataset. + This is a modified version of the original Global Wheat Head Dataset 2021. + + The current version does not contain test or validation labels, as it is being used in a + currently-running competition. + After the competition concludes in July 2021, we will update the dataset to contain the + final splits with test and validation labels, and add the dataset to the official WILDS + benchmark. + + Supported `split_scheme`: + - 'official' + Input (x): + 1024 x 1024 RGB images of wheat field canopy starting from anthesis (flowering) to ripening. + Output (y): + y is a n x 4-dimensional vector where each line represents a box coordinate (x_min, y_min, x_max, y_max) + Metadata: + Each image is annotated with the ID of the domain (location_date_sensor) it came from (integer from 0 to 46). + Website: + http://www.global-wheat.com/ + Original publication: + @article{david_global_2020, + title = {Global {Wheat} {Head} {Detection} ({GWHD}) {Dataset}: {A} {Large} and {Diverse} {Dataset} of {High}-{Resolution} {RGB}-{Labelled} {Images} to {Develop} and {Benchmark} {Wheat} {Head} {Detection} {Methods}}, + volume = {2020}, + url = {https://doi.org/10.34133/2020/3521852}, + doi = {10.34133/2020/3521852}, + journal = {Plant Phenomics}, + author = {David, Etienne and Madec, Simon and Sadeghi-Tehran, Pouria and Aasen, Helge and Zheng, Bangyou and Liu, Shouyang and Kirchgessner, Norbert and Ishikawa, Goro and Nagasawa, Koichi and Badhon, Minhajul A. and Pozniak, Curtis and de Solan, Benoit and Hund, Andreas and Chapman, Scott C. and Baret, Frédéric and Stavness, Ian and Guo, Wei}, + month = Aug, + year = {2020}, + note = {Publisher: AAAS}, + pages = {3521852}, + } + License: + This dataset is distributed under the MIT license. + """ + + _dataset_name = 'globalwheat' + + # Version 0.9 corresponds to the final dataset, but without the validation and test labels, + # since it is being used in a currently-running competition (http://www.global-wheat.com/). + # Users can submit their val+test predictions to the competition to obtain an estimate of + # held-out performance computed on a fraction of those predictions; + # please see the tutorial at https://www.aicrowd.com/challenges/global-wheat-challenge-2021. + # We will update the dataset to include these labels and update the splits after the + # competition ends in July 2021. + _versions_dict = { + '0.9': { + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x8ba9122a41454997afdfb78762d390cf/contents/blob/', + 'compressed_size': 10_280_247_296}} + + def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): + + self._version = version + self._data_dir = self.initialize_data_dir(root_dir, download) + self._original_resolution = (1024, 1024) + self.root = Path(self.data_dir) + self._is_detection = True + self._is_classification = False + self._y_size = None + self._n_classes = 1 + + self._split_scheme = split_scheme + + # Get filenames + if split_scheme == "official": + train_data_df = pd.read_csv(self.root / f'official_train.csv') + val_data_df = pd.read_csv(self.root / f'official_val.csv') + test_data_df = pd.read_csv(self.root / f'official_test.csv') + + elif split_scheme == "ood_with_subsampled_test": + if version == "0.9": + print("Warning: ood_with_subsampled_test is not available in 0.9") + else: + train_data_df = pd.read_csv(self.root / f'official_train.csv') + val_data_df = pd.read_csv(self.root / f'official_val.csv') + test_data_df = pd.read_csv(self.root / f'in-dist_test.csv') + + elif split_scheme == "in-dist": + if version == "0.9": + print("Warning: ood_with_subsampled_test is not available in 0.9") + else: + train_data_df = pd.read_csv(self.root / f'in-dist_train.csv') + val_data_df = pd.read_csv(self.root / f'official_val.csv') + test_data_df = pd.read_csv(self.root / f'in-dist_test.csv') + + self._image_array = [] + self._split_array, self._y_array, self._metadata_array = [], [], [] + + for i, df in enumerate([train_data_df, val_data_df, test_data_df]): + self._image_array.extend(list(df['image_name'].values)) + boxes_string = list(df['BoxesString'].values) + all_boxes = [GlobalWheatDataset._decode_string(box_string) for box_string in boxes_string] + self._split_array.extend([i] * len(all_boxes)) + + labels = [{ + "boxes": torch.stack([ + torch.tensor(box) + for box in boxes + ]), + "labels": torch.tensor([1]*len(boxes)).long() + } if len(boxes) > 0 else { + "boxes": torch.empty(0,4), + "labels": torch.empty(0,dtype=torch.long) + } for boxes in all_boxes] + + self._y_array.extend(labels) + self._metadata_array.extend(list(df['domain'].values)) + + self._split_array = np.array(self._split_array) + self._metadata_array = torch.tensor(self._metadata_array, + dtype=torch.long).unsqueeze(1) + self._metadata_fields = ['location_date_sensor'] + self._eval_grouper = CombinatorialGrouper( + dataset=self, + groupby_fields=['location_date_sensor']) + self._metric = DetectionAccuracy() + self._collate = GlobalWheatDataset._collate_fn + + super().__init__(root_dir, download, split_scheme) + + def get_input(self, idx): + """ + Returns x for a given idx. + """ + img_filename = self.root / "images" / self._image_array[idx] + x = Image.open(img_filename) + return x + + def eval(self, y_pred, y_true, metadata): + """ + The main evaluation metric, detection_acc_avg_dom, + measures the simple average of the detection accuracies + of each domain. + """ + results, results_str = self.standard_group_eval( + self._metric, + self._eval_grouper, + y_pred, y_true, metadata) + + detection_accs = [] + for k, v in results.items(): + if k.startswith('detection_acc_location_date_sensor:'): + d = k.split(':')[1] + count = results[f'count_location_date_sensor:{d}'] + if count > 0: + detection_accs.append(v) + detection_acc_avg_dom = np.array(detection_accs).mean() + results['detection_acc_avg_dom'] = detection_acc_avg_dom + results_str = f'Average detection_acc across domains: {detection_acc_avg_dom:.3f}\n' + results_str + return results, results_str + + @staticmethod + def _decode_string(box_string): + """ + Helper method to decode each box_string in the BoxesString field of the data CSVs + """ + if box_string == "no_box": + return np.zeros((0,4)) + else: + try: + boxes = np.array([np.array([int(i) for i in box.split(" ")]) + for box in box_string.split(";")]) + return boxes + except: + print(box_string) + print("Submission is not well formatted. empty boxes will be returned") + return np.zeros((0,4)) + + @staticmethod + def _collate_fn(batch): + """ + Stack x (batch[0]) and metadata (batch[2]), but not y. + originally, batch = (item1, item2, item3, item4) + after zip, batch = [(item1[0], item2[0], ..), ..] + """ + batch = list(zip(*batch)) + batch[0] = torch.stack(batch[0]) + batch[1] = list(batch[1]) + batch[2] = torch.stack(batch[2]) + return tuple(batch) From 89ef2c4a5942b0ea803b2069fc11bb70c8db5249 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Tue, 25 May 2021 23:42:15 -0700 Subject: [PATCH 193/244] Add warning message for CORAL/IRM for detection --- examples/algorithms/initializer.py | 8 +++++--- examples/run_expt.py | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/algorithms/initializer.py b/examples/algorithms/initializer.py index 180e9ff5..a25afc8f 100644 --- a/examples/algorithms/initializer.py +++ b/examples/algorithms/initializer.py @@ -26,6 +26,8 @@ def initialize_algorithm(config, datasets, train_grouper): elif train_dataset.is_detection: # For detection, d_out is the number of classes d_out = train_dataset.n_classes + if config.algorithm in ['deepCORAL', 'IRM']: + raise ValueError(f'{config.algorithm} is not currently supported for detection datasets.') else: # For regression, we have one output per target dimension d_out = train_dataset.y_size @@ -35,7 +37,7 @@ def initialize_algorithm(config, datasets, train_grouper): loss = initialize_loss(config, d_out) metric = algo_log_metrics[config.algo_log_metric] - if config.algorithm=='ERM': + if config.algorithm == 'ERM': algorithm = ERM( config=config, d_out=d_out, @@ -54,7 +56,7 @@ def initialize_algorithm(config, datasets, train_grouper): metric=metric, n_train_steps=n_train_steps, is_group_in_train=is_group_in_train) - elif config.algorithm=='deepCORAL': + elif config.algorithm == 'deepCORAL': algorithm = DeepCORAL( config=config, d_out=d_out, @@ -62,7 +64,7 @@ def initialize_algorithm(config, datasets, train_grouper): loss=loss, metric=metric, n_train_steps=n_train_steps) - elif config.algorithm=='IRM': + elif config.algorithm == 'IRM': algorithm = IRM( config=config, d_out=d_out, diff --git a/examples/run_expt.py b/examples/run_expt.py index 8d8343b7..b37825ae 100644 --- a/examples/run_expt.py +++ b/examples/run_expt.py @@ -119,7 +119,7 @@ def main(): # For the GlobalWheat detection dataset, # we need to change the multiprocessing strategy or there will be - # too many open file descriptors + # too many open file descriptors. if config.dataset == 'globalwheat': torch.multiprocessing.set_sharing_strategy('file_system') From 35c92901d171be88eb3a62e1709b481e393d8669 Mon Sep 17 00:00:00 2001 From: aikanor Date: Fri, 28 May 2021 08:56:21 -0700 Subject: [PATCH 194/244] revert to better-generalizing version --- examples/configs/datasets.py | 2 +- wilds/datasets/encodetfbs_dataset.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 671219a7..5c4dc667 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -118,7 +118,7 @@ 'val_metric_decreasing': False, 'optimizer': 'Adam', 'scheduler': 'MultiStepLR', - 'scheduler_kwargs': {'milestones':[6,9], 'gamma': 0.1}, # used to be 6, 9, with 12 epochs + 'scheduler_kwargs': {'milestones':[3,6], 'gamma': 0.1}, # used to be 6, 9, with 12 epochs # 'scheduler': 'linear_schedule_with_warmup', # 'scheduler_kwargs': {'num_warmup_steps': 800}, # about 160 minibatches per epoch 'batch_size': 256, diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index c3b7bc09..82ec6d81 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -322,7 +322,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' # Remove all-zero sequences from training. train_mask = (self._split_array == self._split_dict['train']) - allzeroes_mask = (self._y_array.nansum(axis=1) == 0).numpy() + allzeroes_mask = (self._y_array.sum(axis=1) == 0).numpy() keep_mask = keep_mask & ~(train_mask & allzeroes_mask) # Subsample the testing and validation indices, to speed up evaluation. From d05d8895e859d1ab627c8c958819226b495cdad1 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Wed, 2 Jun 2021 11:48:07 -0700 Subject: [PATCH 195/244] change random to torch.randint to avoid issue with workers having the same random seed --- examples/transforms.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/transforms.py b/examples/transforms.py index dec68c2d..b1c9b97a 100644 --- a/examples/transforms.py +++ b/examples/transforms.py @@ -127,18 +127,18 @@ def standardize(x: torch.Tensor) -> torch.Tensor: return TF.normalize(x, mean, std) t_standardize = transforms.Lambda(lambda x: standardize(x)) - def random_d8(x: torch.Tensor) -> torch.Tensor: - angle = random.choice([0, 90, 180, 270]) + angles = [0, 90, 180, 270] + def random_rotation(x: torch.Tensor) -> torch.Tensor: + angle = angles[torch.randint(low=0, high=len(angles), size=(1,))] if angle > 0: x = TF.rotate(x, angle) - if random.random() < 0.5: - x = TF.hflip(x) return x - t_random_d8 = transforms.Lambda(lambda x: random_d8(x)) + t_random_rotation = transforms.Lambda(lambda x: random_rotation(x)) if is_training: transforms_ls = [ - t_random_d8, + t_random_rotation, + transforms.RandomHorizontalFlip(), transforms.ToTensor(), t_standardize, ] From 908334424c38fe94eb639c7773fa1fbc20bd7085 Mon Sep 17 00:00:00 2001 From: Etienne David Date: Thu, 24 Jun 2021 18:14:03 +0200 Subject: [PATCH 196/244] change globalwheat for the new splits --- wilds/datasets/globalwheat_dataset.py | 33 ++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 6 deletions(-) diff --git a/wilds/datasets/globalwheat_dataset.py b/wilds/datasets/globalwheat_dataset.py index d4683dc3..fb4c3179 100644 --- a/wilds/datasets/globalwheat_dataset.py +++ b/wilds/datasets/globalwheat_dataset.py @@ -57,7 +57,11 @@ class GlobalWheatDataset(WILDSDataset): _versions_dict = { '0.9': { 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x8ba9122a41454997afdfb78762d390cf/contents/blob/', - 'compressed_size': 10_280_247_296}} + 'compressed_size': 10_280_247_296}, + '1.0': { + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x8ba9122a41454997afdfb78762d390cf/contents/blob/', + 'compressed_size': 10_280_247_296} + } def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): @@ -84,15 +88,32 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' else: train_data_df = pd.read_csv(self.root / f'official_train.csv') val_data_df = pd.read_csv(self.root / f'official_val.csv') - test_data_df = pd.read_csv(self.root / f'in-dist_test.csv') + test_data_df = pd.read_csv(self.root / f'in_dist_test.csv') elif split_scheme == "in-dist": if version == "0.9": - print("Warning: ood_with_subsampled_test is not available in 0.9") + print("Warning: in-dist is not available in 0.9") + else: + train_data_df = pd.read_csv(self.root / f'in_dist_train.csv') + val_data_df = pd.read_csv(self.root / f'official_val.csv') + test_data_df = pd.read_csv(self.root / f'in_dist_test.csv') + + elif split_scheme == "fixed-train": + if version == "0.9": + print("Warning: fixed-train is not available in 0.9") else: - train_data_df = pd.read_csv(self.root / f'in-dist_train.csv') + train_data_df = pd.read_csv(self.root / f'fixed_train_train.csv') + val_data_df = pd.read_csv(self.root / f'fixed_train_val.csv') + test_data_df = pd.read_csv(self.root / f'fixed_train_test.csv') + + elif split_scheme == "fixed-test": + if version == "0.9": + print("Warning: fixed-test is not available in 0.9") + else: + train_data_df = pd.read_csv(self.root / f'fixed_test_train.csv') val_data_df = pd.read_csv(self.root / f'official_val.csv') - test_data_df = pd.read_csv(self.root / f'in-dist_test.csv') + test_data_df = pd.read_csv(self.root / f'fixed_test_test.csv') + self._image_array = [] self._split_array, self._y_array, self._metadata_array = [], [], [] @@ -169,7 +190,7 @@ def _decode_string(box_string): return np.zeros((0,4)) else: try: - boxes = np.array([np.array([int(i) for i in box.split(" ")]) + boxes = np.array([np.array([int(eval(i)) for i in box.split(" ")]) for box in box_string.split(";")]) return boxes except: From 98c2504a9a1c925f991aef1d16580b941c65a203 Mon Sep 17 00:00:00 2001 From: Etienne David Date: Fri, 25 Jun 2021 19:21:40 +0200 Subject: [PATCH 197/244] add domain decoder --- wilds/datasets/globalwheat_dataset.py | 52 ++++++++++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/wilds/datasets/globalwheat_dataset.py b/wilds/datasets/globalwheat_dataset.py index fb4c3179..72e55476 100644 --- a/wilds/datasets/globalwheat_dataset.py +++ b/wilds/datasets/globalwheat_dataset.py @@ -7,6 +7,56 @@ from wilds.common.grouper import CombinatorialGrouper from wilds.common.metrics.all_metrics import DetectionAccuracy + +DATASETS_DECODER = {0: 'Rres_1', + 1: 'NMBU_2', + 2: 'NMBU_1', + 3: 'Arvalis_9', + 4: 'Arvalis_11', + 5: 'Arvalis_6', + 6: 'Arvalis_5', + 7: 'Arvalis_7', + 8: 'Inrae_1', + 9: 'Arvalis_10', + 10: 'Arvalis_12', + 11: 'Arvalis_4', + 12: 'Arvalis_3', + 13: 'Arvalis_2', + 14: 'Arvalis_1', + 15: 'Arvalis_8', + 16: 'Ethz_1', + 17: 'ULiège-GxABT_1', + 18: 'Utokyo_2', + 19: 'Utokyo_1', + 20: 'Utokyo_3', + 21: 'NAU_1', + 22: 'Ukyoto_1', + 23: 'NAU_3', + 24: 'NAU_2', + 25: 'ARC_1', + 26: 'UQ_11', + 27: 'UQ_10', + 28: 'UQ_9', + 29: 'UQ_8', + 30: 'UQ_6', + 31: 'Terraref_2', + 32: 'Terraref_1', + 33: 'KSU_4', + 34: 'KSU_3', + 35: 'KSU_2', + 36: 'KSU_1', + 37: 'CIMMYT_3', + 38: 'CIMMYT_2', + 39: 'CIMMYT_1', + 40: 'UQ_6', + 41: 'UQ_5', + 42: 'UQ_4', + 43: 'UQ_3', + 44: 'UQ_2', + 45: 'UQ_1', + 46: 'Usask_1' +} + class GlobalWheatDataset(WILDSDataset): """ The GlobalWheat-WILDS wheat head localization dataset. @@ -136,7 +186,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' } for boxes in all_boxes] self._y_array.extend(labels) - self._metadata_array.extend(list(df['domain'].values)) + self._metadata_array.extend([DATASETS_DECODER[int(item)] for item in df['domain'].values])) self._split_array = np.array(self._split_array) self._metadata_array = torch.tensor(self._metadata_array, From fbffb152216e417456931c1f714f93c800527167 Mon Sep 17 00:00:00 2001 From: Etienne David Date: Fri, 25 Jun 2021 19:29:36 +0200 Subject: [PATCH 198/244] add metadata of label 2 country --- wilds/datasets/globalwheat_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wilds/datasets/globalwheat_dataset.py b/wilds/datasets/globalwheat_dataset.py index 72e55476..b51380f2 100644 --- a/wilds/datasets/globalwheat_dataset.py +++ b/wilds/datasets/globalwheat_dataset.py @@ -186,7 +186,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' } for boxes in all_boxes] self._y_array.extend(labels) - self._metadata_array.extend([DATASETS_DECODER[int(item)] for item in df['domain'].values])) + self._metadata_array.extend([int(item) for item in df['domain'].values]) self._split_array = np.array(self._split_array) self._metadata_array = torch.tensor(self._metadata_array, From 3d7ad0ba2b6aa5e58e94cdde68af73a9bd41c957 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Sat, 26 Jun 2021 15:50:41 -0700 Subject: [PATCH 199/244] add metadata to globalwheat --- examples/configs/datasets.py | 2 +- examples/train.py | 3 + wilds/datasets/globalwheat_dataset.py | 174 ++++++++++++++++++-------- 3 files changed, 128 insertions(+), 51 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 209f4be4..5b4c2241 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -304,7 +304,7 @@ 'pretrained': True }, 'loss_function': 'fasterrcnn_criterion', - 'groupby_fields': ['location_date_sensor'], + 'groupby_fields': ['session'], 'val_metric': 'detection_acc_avg_dom', 'val_metric_decreasing': False, 'algo_log_metric': None, # TODO diff --git a/examples/train.py b/examples/train.py index 37184b61..602f4f8d 100644 --- a/examples/train.py +++ b/examples/train.py @@ -32,6 +32,9 @@ def run_epoch(algorithm, dataset, general_logger, epoch, config, train): else: batch_results = algorithm.evaluate(batch) + import IPython + IPython.embed() + # These tensors are already detached, but we need to clone them again # Otherwise they don't get garbage collected properly in some versions # The extra detach is just for safety diff --git a/wilds/datasets/globalwheat_dataset.py b/wilds/datasets/globalwheat_dataset.py index b51380f2..50a6056e 100644 --- a/wilds/datasets/globalwheat_dataset.py +++ b/wilds/datasets/globalwheat_dataset.py @@ -7,55 +7,103 @@ from wilds.common.grouper import CombinatorialGrouper from wilds.common.metrics.all_metrics import DetectionAccuracy +SESSIONS = [ + 'Rres_1', + 'NMBU_2', + 'NMBU_1', + 'Arvalis_9', + 'Arvalis_11', + 'Arvalis_6', + 'Arvalis_5', + 'Arvalis_7', + 'Inrae_1', + 'Arvalis_10', + 'Arvalis_12', + 'Arvalis_4', + 'Arvalis_3', + 'Arvalis_2', + 'Arvalis_1', + 'Arvalis_8', + 'Ethz_1', + 'ULiège-GxABT_1', + 'Utokyo_2', + 'Utokyo_1', + 'Utokyo_3', + 'NAU_1', + 'Ukyoto_1', + 'NAU_3', + 'NAU_2', + 'ARC_1', + 'UQ_11', + 'UQ_10', + 'UQ_9', + 'UQ_8', + 'UQ_6', + 'Terraref_2', + 'Terraref_1', + 'KSU_4', + 'KSU_3', + 'KSU_2', + 'KSU_1', + 'CIMMYT_3', + 'CIMMYT_2', + 'CIMMYT_1', + 'UQ_6', + 'UQ_5', + 'UQ_4', + 'UQ_3', + 'UQ_2', + 'UQ_1', + 'Usask_1', +] -DATASETS_DECODER = {0: 'Rres_1', - 1: 'NMBU_2', - 2: 'NMBU_1', - 3: 'Arvalis_9', - 4: 'Arvalis_11', - 5: 'Arvalis_6', - 6: 'Arvalis_5', - 7: 'Arvalis_7', - 8: 'Inrae_1', - 9: 'Arvalis_10', - 10: 'Arvalis_12', - 11: 'Arvalis_4', - 12: 'Arvalis_3', - 13: 'Arvalis_2', - 14: 'Arvalis_1', - 15: 'Arvalis_8', - 16: 'Ethz_1', - 17: 'ULiège-GxABT_1', - 18: 'Utokyo_2', - 19: 'Utokyo_1', - 20: 'Utokyo_3', - 21: 'NAU_1', - 22: 'Ukyoto_1', - 23: 'NAU_3', - 24: 'NAU_2', - 25: 'ARC_1', - 26: 'UQ_11', - 27: 'UQ_10', - 28: 'UQ_9', - 29: 'UQ_8', - 30: 'UQ_6', - 31: 'Terraref_2', - 32: 'Terraref_1', - 33: 'KSU_4', - 34: 'KSU_3', - 35: 'KSU_2', - 36: 'KSU_1', - 37: 'CIMMYT_3', - 38: 'CIMMYT_2', - 39: 'CIMMYT_1', - 40: 'UQ_6', - 41: 'UQ_5', - 42: 'UQ_4', - 43: 'UQ_3', - 44: 'UQ_2', - 45: 'UQ_1', - 46: 'Usask_1' -} +COUNTRIES = [ + 'Switzerland', + 'UK', + 'Belgium', + 'Norway', + 'France', + 'Canada', + 'US', + 'Mexico', + 'Japan', + 'China', + 'Australia', + 'Sudan', +] + +LOCATIONS = [ + 'Baima', + 'Brookstead', + 'Ciudad Obregon', + 'Gatton', + 'Gembloux', + 'Gréoux', + 'KSU', + 'Kyoto', + 'Maricopa, AZ', + 'McAllister', + 'Mons', + 'NARO-Hokkaido', + 'NARO-Tsukuba', + 'NMBU', + 'Rothamsted', + 'Saskatchewan', + 'Toulouse', + 'Usask', + 'VLB', + 'VSC', + 'Wad Medani', +] + +STAGES = [ + 'Filling', + 'Filling - Ripening', + 'multiple', + 'Post-flowering', + 'Post-Flowering', + 'Ripening', +] class GlobalWheatDataset(WILDSDataset): """ @@ -191,10 +239,36 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' self._split_array = np.array(self._split_array) self._metadata_array = torch.tensor(self._metadata_array, dtype=torch.long).unsqueeze(1) - self._metadata_fields = ['location_date_sensor'] + self._metadata_array = torch.cat( + (self._metadata_array, + torch.zeros( + (len(self._metadata_array), 3), + dtype=torch.long)), + dim=1) + + domain_df = pd.read_csv(self.root / 'metadata_domain.csv', sep=';') + for session_idx, session_name in enumerate(SESSIONS): + idx = pd.Index(domain_df['name']).get_loc(session_name) + country = domain_df.loc[idx, 'country'] + location = domain_df.loc[idx, 'location'] + stage = domain_df.loc[idx, 'development_stage'] + + session_mask = (self._metadata_array[:, 0] == session_idx) + self._metadata_array[session_mask, 1] = COUNTRIES.index(country) + self._metadata_array[session_mask, 2] = LOCATIONS.index(location) + self._metadata_array[session_mask, 3] = STAGES.index(stage) + + self._metadata_fields = ['session', 'country', 'location', 'stage'] + self._metadata_map = { + 'session': SESSIONS, + 'country': COUNTRIES, + 'location': LOCATIONS, + 'stage': STAGES, + } + self._eval_grouper = CombinatorialGrouper( dataset=self, - groupby_fields=['location_date_sensor']) + groupby_fields=['session']) self._metric = DetectionAccuracy() self._collate = GlobalWheatDataset._collate_fn From 5ac5b3a5d806c37b3ba3cf9557ad823c984c0788 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Mon, 28 Jun 2021 11:09:02 -0700 Subject: [PATCH 200/244] update with new domain ordering --- examples/train.py | 5 +- wilds/datasets/globalwheat_dataset.py | 66 +++++++++++++-------------- 2 files changed, 34 insertions(+), 37 deletions(-) diff --git a/examples/train.py b/examples/train.py index 602f4f8d..5b0b591e 100644 --- a/examples/train.py +++ b/examples/train.py @@ -30,10 +30,7 @@ def run_epoch(algorithm, dataset, general_logger, epoch, config, train): if train: batch_results = algorithm.update(batch) else: - batch_results = algorithm.evaluate(batch) - - import IPython - IPython.embed() + batch_results = algorithm.evaluate(batch) # These tensors are already detached, but we need to clone them again # Otherwise they don't get garbage collected properly in some versions diff --git a/wilds/datasets/globalwheat_dataset.py b/wilds/datasets/globalwheat_dataset.py index 50a6056e..cd943685 100644 --- a/wilds/datasets/globalwheat_dataset.py +++ b/wilds/datasets/globalwheat_dataset.py @@ -8,53 +8,53 @@ from wilds.common.metrics.all_metrics import DetectionAccuracy SESSIONS = [ - 'Rres_1', - 'NMBU_2', - 'NMBU_1', - 'Arvalis_9', - 'Arvalis_11', - 'Arvalis_6', + 'Arvalis_1', + 'Arvalis_2', + 'Arvalis_3', + 'Arvalis_4', 'Arvalis_5', + 'Arvalis_6', 'Arvalis_7', - 'Inrae_1', + 'Arvalis_8', + 'Arvalis_9', 'Arvalis_10', + 'Arvalis_11', 'Arvalis_12', - 'Arvalis_4', - 'Arvalis_3', - 'Arvalis_2', - 'Arvalis_1', - 'Arvalis_8', - 'Ethz_1', + 'ETHZ_1', + 'Inrae_1', + 'NMBU_1', + 'NMBU_2', + 'Rres_1', 'ULiège-GxABT_1', - 'Utokyo_2', 'Utokyo_1', + 'Utokyo_2', 'Utokyo_3', - 'NAU_1', 'Ukyoto_1', - 'NAU_3', + 'NAU_1', 'NAU_2', + 'NAU_3', 'ARC_1', - 'UQ_11', - 'UQ_10', - 'UQ_9', - 'UQ_8', + 'UQ_1', + 'UQ_2', + 'UQ_3', + 'UQ_4', + 'UQ_5', 'UQ_6', - 'Terraref_2', + 'UQ_7', + 'UQ_8', + 'UQ_9', + 'UQ_10', + 'UQ_11', 'Terraref_1', - 'KSU_4', - 'KSU_3', - 'KSU_2', + 'Terraref_2', 'KSU_1', - 'CIMMYT_3', - 'CIMMYT_2', + 'KSU_2', + 'KSU_3', + 'KSU_4', 'CIMMYT_1', - 'UQ_6', - 'UQ_5', - 'UQ_4', - 'UQ_3', - 'UQ_2', - 'UQ_1', - 'Usask_1', + 'CIMMYT_2', + 'CIMMYT_3', + 'Usask_1' ] COUNTRIES = [ From 84c677a203e5dec8761d1d2a66f8fcdc7f15c04c Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Wed, 30 Jun 2021 10:54:58 -0700 Subject: [PATCH 201/244] camelyon processing cleanup --- .../camelyon17/generate_all_patch_coords.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/dataset_preprocessing/camelyon17/generate_all_patch_coords.py b/dataset_preprocessing/camelyon17/generate_all_patch_coords.py index 69cca6cf..60a509ee 100644 --- a/dataset_preprocessing/camelyon17/generate_all_patch_coords.py +++ b/dataset_preprocessing/camelyon17/generate_all_patch_coords.py @@ -109,11 +109,7 @@ def _record_patches(center_size, slide, slide_map, patch_level, mask_level, tumor_mask, tissue_mask, normal_mask, tumor_threshold, - tumor_sel_ratio, - tumor_sel_max, - normal_threshold, - normal_sel_ratio, - normal_sel_max, + normal_threshold, **args): """ Extract all tumor and non-tumor patches from a slide, using the given masks. @@ -197,11 +193,7 @@ def generate_file(patient, node, xml_path, slide_path, folder_path): 'mask_level' : MASK_LEVEL, 'center_size' : CENTER_SIZE, 'tumor_threshold' : 0, - 'tumor_sel_ratio' : 1, - 'tumor_sel_max' : 100000, 'normal_threshold' : 0.2, - 'normal_sel_ratio' : 1, - 'normal_sel_max' : 100000, 'mask_folder_path' : folder_path, 'make_map' : True } From bd3c28afecb056adeb9945e3381fe0a9b860e9c7 Mon Sep 17 00:00:00 2001 From: Unknown Date: Thu, 1 Jul 2021 13:30:48 +0200 Subject: [PATCH 202/244] fix globalwheat typo to make average metric works --- wilds/datasets/globalwheat_dataset.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/wilds/datasets/globalwheat_dataset.py b/wilds/datasets/globalwheat_dataset.py index cd943685..9a93de7d 100644 --- a/wilds/datasets/globalwheat_dataset.py +++ b/wilds/datasets/globalwheat_dataset.py @@ -123,7 +123,7 @@ class GlobalWheatDataset(WILDSDataset): Output (y): y is a n x 4-dimensional vector where each line represents a box coordinate (x_min, y_min, x_max, y_max) Metadata: - Each image is annotated with the ID of the domain (location_date_sensor) it came from (integer from 0 to 46). + Each image is annotated with the ID of the domain (session) it came from (integer from 0 to 46). Website: http://www.global-wheat.com/ Original publication: @@ -295,14 +295,15 @@ def eval(self, y_pred, y_true, metadata): detection_accs = [] for k, v in results.items(): - if k.startswith('detection_acc_location_date_sensor:'): + if k.startswith('detection_acc_session:'): d = k.split(':')[1] - count = results[f'count_location_date_sensor:{d}'] + count = results[f'count_session:{d}'] if count > 0: detection_accs.append(v) detection_acc_avg_dom = np.array(detection_accs).mean() + print("DEBUG ", detection_acc_avg_dom) results['detection_acc_avg_dom'] = detection_acc_avg_dom - results_str = f'Average detection_acc across domains: {detection_acc_avg_dom:.3f}\n' + results_str + results_str = f'Average detection_acc across session: {detection_acc_avg_dom:.3f}\n' + results_str return results, results_str @staticmethod From 5770a910f7c5f9a09fdaa695fe56967fec6db6a1 Mon Sep 17 00:00:00 2001 From: Unknown Date: Thu, 1 Jul 2021 13:36:56 +0200 Subject: [PATCH 203/244] remove debug line --- wilds/datasets/globalwheat_dataset.py | 1 - 1 file changed, 1 deletion(-) diff --git a/wilds/datasets/globalwheat_dataset.py b/wilds/datasets/globalwheat_dataset.py index 9a93de7d..6aba9b60 100644 --- a/wilds/datasets/globalwheat_dataset.py +++ b/wilds/datasets/globalwheat_dataset.py @@ -301,7 +301,6 @@ def eval(self, y_pred, y_true, metadata): if count > 0: detection_accs.append(v) detection_acc_avg_dom = np.array(detection_accs).mean() - print("DEBUG ", detection_acc_avg_dom) results['detection_acc_avg_dom'] = detection_acc_avg_dom results_str = f'Average detection_acc across session: {detection_acc_avg_dom:.3f}\n' + results_str return results, results_str From d20bac2bde1dfdbef1720d2e388d010353ba99ee Mon Sep 17 00:00:00 2001 From: Unknown Date: Thu, 1 Jul 2021 13:43:43 +0200 Subject: [PATCH 204/244] add bundle + update description --- wilds/datasets/globalwheat_dataset.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/wilds/datasets/globalwheat_dataset.py b/wilds/datasets/globalwheat_dataset.py index 6aba9b60..610b942a 100644 --- a/wilds/datasets/globalwheat_dataset.py +++ b/wilds/datasets/globalwheat_dataset.py @@ -112,9 +112,6 @@ class GlobalWheatDataset(WILDSDataset): The current version does not contain test or validation labels, as it is being used in a currently-running competition. - After the competition concludes in July 2021, we will update the dataset to contain the - final splits with test and validation labels, and add the dataset to the official WILDS - benchmark. Supported `split_scheme`: - 'official' @@ -139,6 +136,14 @@ class GlobalWheatDataset(WILDSDataset): note = {Publisher: AAAS}, pages = {3521852}, } + @misc{david2021global, + title={Global Wheat Head Dataset 2021: more diversity to improve the benchmarking of wheat head localization methods}, + author={Etienne David and Mario Serouart and Daniel Smith and Simon Madec and Kaaviya Velumani and Shouyang Liu and Xu Wang and Francisco Pinto Espinosa and Shahameh Shafiee and Izzat S. A. Tahir and Hisashi Tsujimoto and Shuhei Nasuda and Bangyou Zheng and Norbert Kichgessner and Helge Aasen and Andreas Hund and Pouria Sadhegi-Tehran and Koichi Nagasawa and Goro Ishikawa and Sébastien Dandrifosse and Alexis Carlier and Benoit Mercatoris and Ken Kuroki and Haozhou Wang and Masanori Ishii and Minhajul A. Badhon and Curtis Pozniak and David Shaner LeBauer and Morten Lilimo and Jesse Poland and Scott Chapman and Benoit de Solan and Frédéric Baret and Ian Stavness and Wei Guo}, + year={2021}, + eprint={2105.07660}, + archivePrefix={arXiv}, + primaryClass={cs.CV} + } License: This dataset is distributed under the MIT license. """ @@ -157,8 +162,8 @@ class GlobalWheatDataset(WILDSDataset): 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x8ba9122a41454997afdfb78762d390cf/contents/blob/', 'compressed_size': 10_280_247_296}, '1.0': { - 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x8ba9122a41454997afdfb78762d390cf/contents/blob/', - 'compressed_size': 10_280_247_296} + 'download_url': 'https://worksheets.codalab.org/bundles/0x03b0584cb00d4ea987aa3269aa2fd2b4/contents/blob/' + } } def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): From 1f3160ed114e9915f7a8ad09329d7265171b03f0 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Sat, 3 Jul 2021 01:16:49 -0700 Subject: [PATCH 205/244] fix url --- wilds/datasets/globalwheat_dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/wilds/datasets/globalwheat_dataset.py b/wilds/datasets/globalwheat_dataset.py index 610b942a..c1e25ba6 100644 --- a/wilds/datasets/globalwheat_dataset.py +++ b/wilds/datasets/globalwheat_dataset.py @@ -137,7 +137,7 @@ class GlobalWheatDataset(WILDSDataset): pages = {3521852}, } @misc{david2021global, - title={Global Wheat Head Dataset 2021: more diversity to improve the benchmarking of wheat head localization methods}, + title={Global Wheat Head Dataset 2021: more diversity to improve the benchmarking of wheat head localization methods}, author={Etienne David and Mario Serouart and Daniel Smith and Simon Madec and Kaaviya Velumani and Shouyang Liu and Xu Wang and Francisco Pinto Espinosa and Shahameh Shafiee and Izzat S. A. Tahir and Hisashi Tsujimoto and Shuhei Nasuda and Bangyou Zheng and Norbert Kichgessner and Helge Aasen and Andreas Hund and Pouria Sadhegi-Tehran and Koichi Nagasawa and Goro Ishikawa and Sébastien Dandrifosse and Alexis Carlier and Benoit Mercatoris and Ken Kuroki and Haozhou Wang and Masanori Ishii and Minhajul A. Badhon and Curtis Pozniak and David Shaner LeBauer and Morten Lilimo and Jesse Poland and Scott Chapman and Benoit de Solan and Frédéric Baret and Ian Stavness and Wei Guo}, year={2021}, eprint={2105.07660}, @@ -162,9 +162,9 @@ class GlobalWheatDataset(WILDSDataset): 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x8ba9122a41454997afdfb78762d390cf/contents/blob/', 'compressed_size': 10_280_247_296}, '1.0': { - 'download_url': 'https://worksheets.codalab.org/bundles/0x03b0584cb00d4ea987aa3269aa2fd2b4/contents/blob/' - } - } + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x03b0584cb00d4ea987aa3269aa2fd2b4/contents/blob/', + 'compressed_size': None} + } def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): From b0e8aca02705f306157b9267cb9157180dca0005 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Sat, 3 Jul 2021 16:28:59 -0700 Subject: [PATCH 206/244] add size for globalwheat --- wilds/datasets/globalwheat_dataset.py | 57 +++++++-------------------- 1 file changed, 15 insertions(+), 42 deletions(-) diff --git a/wilds/datasets/globalwheat_dataset.py b/wilds/datasets/globalwheat_dataset.py index c1e25ba6..1c4dbf49 100644 --- a/wilds/datasets/globalwheat_dataset.py +++ b/wilds/datasets/globalwheat_dataset.py @@ -108,10 +108,7 @@ class GlobalWheatDataset(WILDSDataset): """ The GlobalWheat-WILDS wheat head localization dataset. - This is a modified version of the original Global Wheat Head Dataset 2021. - - The current version does not contain test or validation labels, as it is being used in a - currently-running competition. + This is a modified version of the original Global Wheat Head Dataset 2021. Supported `split_scheme`: - 'official' @@ -149,21 +146,10 @@ class GlobalWheatDataset(WILDSDataset): """ _dataset_name = 'globalwheat' - - # Version 0.9 corresponds to the final dataset, but without the validation and test labels, - # since it is being used in a currently-running competition (http://www.global-wheat.com/). - # Users can submit their val+test predictions to the competition to obtain an estimate of - # held-out performance computed on a fraction of those predictions; - # please see the tutorial at https://www.aicrowd.com/challenges/global-wheat-challenge-2021. - # We will update the dataset to include these labels and update the splits after the - # competition ends in July 2021. _versions_dict = { - '0.9': { - 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x8ba9122a41454997afdfb78762d390cf/contents/blob/', - 'compressed_size': 10_280_247_296}, '1.0': { 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x03b0584cb00d4ea987aa3269aa2fd2b4/contents/blob/', - 'compressed_size': None} + 'compressed_size': 10_286_874_624} } def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): @@ -185,38 +171,25 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' val_data_df = pd.read_csv(self.root / f'official_val.csv') test_data_df = pd.read_csv(self.root / f'official_test.csv') - elif split_scheme == "ood_with_subsampled_test": - if version == "0.9": - print("Warning: ood_with_subsampled_test is not available in 0.9") - else: - train_data_df = pd.read_csv(self.root / f'official_train.csv') - val_data_df = pd.read_csv(self.root / f'official_val.csv') - test_data_df = pd.read_csv(self.root / f'in_dist_test.csv') + elif split_scheme == "official_with_subsampled_test": + train_data_df = pd.read_csv(self.root / f'official_train.csv') + val_data_df = pd.read_csv(self.root / f'official_val.csv') + test_data_df = pd.read_csv(self.root / f'fixed_test_test.csv') elif split_scheme == "in-dist": - if version == "0.9": - print("Warning: in-dist is not available in 0.9") - else: - train_data_df = pd.read_csv(self.root / f'in_dist_train.csv') - val_data_df = pd.read_csv(self.root / f'official_val.csv') - test_data_df = pd.read_csv(self.root / f'in_dist_test.csv') + train_data_df = pd.read_csv(self.root / f'in_dist_train.csv') + val_data_df = pd.read_csv(self.root / f'official_val.csv') + test_data_df = pd.read_csv(self.root / f'in_dist_test.csv') elif split_scheme == "fixed-train": - if version == "0.9": - print("Warning: fixed-train is not available in 0.9") - else: - train_data_df = pd.read_csv(self.root / f'fixed_train_train.csv') - val_data_df = pd.read_csv(self.root / f'fixed_train_val.csv') - test_data_df = pd.read_csv(self.root / f'fixed_train_test.csv') + train_data_df = pd.read_csv(self.root / f'fixed_train_train.csv') + val_data_df = pd.read_csv(self.root / f'fixed_train_val.csv') + test_data_df = pd.read_csv(self.root / f'fixed_train_test.csv') elif split_scheme == "fixed-test": - if version == "0.9": - print("Warning: fixed-test is not available in 0.9") - else: - train_data_df = pd.read_csv(self.root / f'fixed_test_train.csv') - val_data_df = pd.read_csv(self.root / f'official_val.csv') - test_data_df = pd.read_csv(self.root / f'fixed_test_test.csv') - + train_data_df = pd.read_csv(self.root / f'fixed_test_train.csv') + val_data_df = pd.read_csv(self.root / f'official_val.csv') + test_data_df = pd.read_csv(self.root / f'fixed_test_test.csv') self._image_array = [] self._split_array, self._y_array, self._metadata_array = [], [], [] From 647922e62d8c15fe2f5e8fd747451d9351b62c1d Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Sun, 4 Jul 2021 14:12:08 -0700 Subject: [PATCH 207/244] update best hyperparameters --- examples/configs/datasets.py | 2 +- examples/run_expt.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 5b4c2241..014e91a0 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -313,7 +313,7 @@ 'scheduler': None, 'batch_size': 4, 'lr': 1e-5, - 'weight_decay': 1e-4, + 'weight_decay': 0, 'n_epochs': 10, 'loader_kwargs': { 'num_workers': 1, diff --git a/examples/run_expt.py b/examples/run_expt.py index cf6f9568..1cc8a08c 100644 --- a/examples/run_expt.py +++ b/examples/run_expt.py @@ -8,6 +8,9 @@ import sys from collections import defaultdict +# TODO: This is needed to test the WILDS package locally. Remove later -Tony +sys.path.insert(1, os.path.join(sys.path[0], '..')) + import wilds from wilds.common.data_loaders import get_train_loader, get_eval_loader from wilds.common.grouper import CombinatorialGrouper From 3a157f2844387bb11d6c6f935da0e7742bf6ba83 Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Sun, 4 Jul 2021 14:13:16 -0700 Subject: [PATCH 208/244] cleanup --- examples/run_expt.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/run_expt.py b/examples/run_expt.py index 1cc8a08c..cf6f9568 100644 --- a/examples/run_expt.py +++ b/examples/run_expt.py @@ -8,9 +8,6 @@ import sys from collections import defaultdict -# TODO: This is needed to test the WILDS package locally. Remove later -Tony -sys.path.insert(1, os.path.join(sys.path[0], '..')) - import wilds from wilds.common.data_loaders import get_train_loader, get_eval_loader from wilds.common.grouper import CombinatorialGrouper From 68a50ca29f6535b790675c17c2a77ca068df0221 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Tue, 6 Jul 2021 21:16:27 -0700 Subject: [PATCH 209/244] n_groups tweak --- examples/configs/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 5c4dc667..81d1ff4e 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -11,7 +11,7 @@ 'lr': 1e-5, 'weight_decay': 0.01, 'n_epochs': 3, - 'n_groups_per_batch': 2, + 'n_groups_per_batch': 4, 'irm_lambda': 1.0, 'coral_penalty_weight': 1.0, 'loader_kwargs': { From 6e546fea1069cede8cfeb82face7abb84382a4e6 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Tue, 6 Jul 2021 21:17:19 -0700 Subject: [PATCH 210/244] encode settings --- examples/configs/datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 81d1ff4e..052d9629 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -125,10 +125,10 @@ 'lr': 1e-3, 'weight_decay': 1e-4, 'n_epochs': 12, - 'n_groups_per_batch': 2, + 'n_groups_per_batch': 4, 'algo_log_metric': 'multitask_binary_accuracy', 'irm_lambda': 100.0, - # 'coral_penalty_weight': 0.1, + 'coral_penalty_weight': 0.1, }, 'fmow': { 'split_scheme': 'official', From 4348d8af3dbd22e430d7008a37d3aa27822b5293 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Tue, 6 Jul 2021 22:10:55 -0700 Subject: [PATCH 211/244] clean up unused ENCODE metrics --- examples/configs/supported.py | 3 +- wilds/__init__.py | 4 +- wilds/common/metrics/all_metrics.py | 57 ---------------------------- wilds/datasets/encodetfbs_dataset.py | 5 +-- 4 files changed, 4 insertions(+), 65 deletions(-) diff --git a/examples/configs/supported.py b/examples/configs/supported.py index 35c44741..1a6b0662 100644 --- a/examples/configs/supported.py +++ b/examples/configs/supported.py @@ -18,8 +18,7 @@ 'mse': MSE(), 'multitask_accuracy': MultiTaskAccuracy(prediction_fn=multiclass_logits_to_pred), 'multitask_binary_accuracy': MultiTaskAccuracy(prediction_fn=binary_logits_to_pred), - 'multitask_avgprec': MultiTaskAveragePrecision(prediction_fn=None), - # 'multitask_preven': MultiTaskPREven(prediction_fn=None), + 'multitask_avgprec': MultiTaskAveragePrecision(prediction_fn=None), None: None, } diff --git a/wilds/__init__.py b/wilds/__init__.py index 0a22c780..7a9ed877 100644 --- a/wilds/__init__.py +++ b/wilds/__init__.py @@ -9,8 +9,7 @@ 'ogb-molpcba', 'poverty', 'fmow', - 'py150', - 'encode-tfbs' + 'py150', ] additional_datasets = [ @@ -19,6 +18,7 @@ 'yelp', 'bdd100k', 'sqf', + 'encode-tfbs' ] supported_datasets = benchmark_datasets + additional_datasets diff --git a/wilds/common/metrics/all_metrics.py b/wilds/common/metrics/all_metrics.py index 98fa2848..e3858aec 100644 --- a/wilds/common/metrics/all_metrics.py +++ b/wilds/common/metrics/all_metrics.py @@ -104,40 +104,6 @@ def _compute_group_wise(self, y_pred, y_true, g, n_groups): def worst(self, metrics): return minimum(metrics) -# Break-even point of precision and recall. This is approx equal to average precision, and added due to currently open numerical issues with having zero true positives in a batch (see https://github.com/scikit-learn/scikit-learn/issues/8245 ). -def calc_PREven(y_true, y_score): - numpos = np.sum(y_true == 1) - top_ndces = np.argsort(y_score)[::-1] - ytr_ord = y_true[top_ndces] - m = {x: np.sum(ytr_ord[:numpos] == x) for x in np.unique(ytr_ord[:numpos])} - p = m[1] if 1 in m else 0 - if (len(m) == 0) or (0 not in m): - return None - else: - return 1.0*p/(m[0]+p) - -class MultiTaskPREven(MultiTaskMetric): - def __init__(self, prediction_fn=None, name=None): - self.prediction_fn = prediction_fn - if name is None: - name = f'preven' - super().__init__(name=name) - - def _compute_flattened(self, flattened_y_pred, flattened_y_true): - if self.prediction_fn is not None: - flattened_y_pred = self.prediction_fn(flattened_y_pred) - ytr = np.array(flattened_y_true.squeeze().detach().cpu().numpy() > 0) - ypr = flattened_y_pred.squeeze().detach().cpu().numpy() - score = calc_PREven(ytr, ypr) - to_ret = torch.tensor(score).to(flattened_y_pred.device) - return to_ret - - def _compute(self, y_pred, y_true): - return self._compute_flattened(y_pred, y_true) - - def worst(self, metrics): - return minimum(metrics) - class Recall(Metric): def __init__(self, prediction_fn=None, name=None, average='binary'): self.prediction_fn = prediction_fn @@ -157,29 +123,6 @@ def _compute(self, y_pred, y_true): def worst(self, metrics): return minimum(metrics) -class AveragePrecision(Metric): - def __init__(self, prediction_fn=None, name=None, average='macro'): - self.prediction_fn = prediction_fn - if name is None: - name = f'avgprec' - if average is not None: - name+=f'-{average}' - self.average = average - super().__init__(name=name) - - def _compute(self, y_pred, y_true): - if self.prediction_fn is not None: - y_pred = self.prediction_fn(y_pred) - score = sklearn.metrics.average_precision_score( - np.array(y_true.squeeze().detach().cpu().numpy() > 0), - y_pred.squeeze().detach().cpu().numpy(), - average=self.average - ) - return torch.tensor(score) - - def worst(self, metrics): - return minimum(metrics) - class F1(Metric): def __init__(self, prediction_fn=None, name=None, average='binary'): self.prediction_fn = prediction_fn diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py index 82ec6d81..9fe44e98 100644 --- a/wilds/datasets/encodetfbs_dataset.py +++ b/wilds/datasets/encodetfbs_dataset.py @@ -345,7 +345,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' self._metadata_df = self._metadata_df[keep_mask] self._split_array = self._split_array[keep_mask] - self._y_array = self._y_array[keep_mask] + self._y_array = self._y_array[keep_mask] self._all_chroms = sorted(list({chrom for _, d in splits.items() for chrom in d['chroms']})) self._all_celltypes = sorted(list({chrom for _, d in splits.items() for chrom in d['celltypes']})) @@ -418,9 +418,6 @@ def get_input(self, idx, window_size=12800): seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end] dnase_bw = self._dnase_allcelltypes[this_metadata['celltype']] dnase_this = np.nan_to_num(dnase_bw.values(chrom, interval_start, interval_end, numpy=True)) -# assert(np.isnan(seq_this).sum() == 0) -# assert(np.isnan(dnase_this).sum() == 0) -# dnase_this = self.norm_signal(dnase_this, this_metadata['celltype']) return torch.tensor(np.column_stack( [seq_this, dnase_this] From 439d741624c69f6f5f908ec42b9c0dac85237188 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Tue, 6 Jul 2021 22:15:07 -0700 Subject: [PATCH 212/244] dataset name change --- dataset_preprocessing/encode-tfbs/README.md | 26 -- .../encode-tfbs/prep_accessibility.py | 180 -------- .../encode-tfbs/prep_datasets.ipynb | 279 ------------ .../encode-tfbs/prep_metadata_labels.py | 146 ------ .../encode-tfbs/prep_sequence.py | 131 ------ examples/configs/datasets.py | 2 +- wilds/__init__.py | 4 +- wilds/datasets/encodetfbs_dataset.py | 430 ------------------ wilds/get_dataset.py | 6 +- 9 files changed, 6 insertions(+), 1198 deletions(-) delete mode 100644 dataset_preprocessing/encode-tfbs/README.md delete mode 100644 dataset_preprocessing/encode-tfbs/prep_accessibility.py delete mode 100644 dataset_preprocessing/encode-tfbs/prep_datasets.ipynb delete mode 100644 dataset_preprocessing/encode-tfbs/prep_metadata_labels.py delete mode 100644 dataset_preprocessing/encode-tfbs/prep_sequence.py delete mode 100644 wilds/datasets/encodetfbs_dataset.py diff --git a/dataset_preprocessing/encode-tfbs/README.md b/dataset_preprocessing/encode-tfbs/README.md deleted file mode 100644 index 1ef362b6..00000000 --- a/dataset_preprocessing/encode-tfbs/README.md +++ /dev/null @@ -1,26 +0,0 @@ -## ENCODE-TFBS-wilds feature generation and preprocessing - -#### Requirements -- pyBigWig - -#### Instructions to create Codalab bundle - -1. Download the human genome sequence (hg19 assembly) in FASTA format from http://hgdownload.cse.ucsc.edu/goldenpath/hg19/bigZips/hg19.fa.gz and extract it into `SEQUENCE_PATH`. - -2. Run `python prep_sequence.py --seq_path SEQUENCE_PATH --output_dir OUTPUT_DIR` to write the fasta file found in `SEQUENCE_PATH` to a numpy array archive in `OUTPUT_DIR`. - -3. Download the DNase accessibility data. This consists of whole-genome DNase files in bigwig format from https://guanfiles.dcmb.med.umich.edu/Leopard/dnase_bigwig/. These are saved with filename `DNASE..fc.signal.bigwig`. - -4. Run `prep_accessibility.py`. - -5. Download the labels from the challenge into a label directory `labels/` created for this purpose: - - The training chromosome labels for the challenge's training cell types from https://www.synapse.org/#!Synapse:syn7413983 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn7415202 for the TF MAX, , downloaded as MAX.train.labels.tsv.gz ). - - The training chromosome labels for the challenge's evaluation cell type (liver) from https://www.synapse.org/#!Synapse:syn8077511 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn8077648 for the TF MAX, downloaded as MAX.train_wc.labels.tsv.gz ). - - The validation chromosome labels for the challenge's training cell types from https://www.synapse.org/#!Synapse:syn8441154 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn8442103 for the TF MAX, downloaded as MAX.val.labels.tsv.gz ). - - The validation chromosome labels for the challenge's evaluation cell type (liver) from https://www.synapse.org/#!Synapse:syn8442975 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn8443021 for the TF MAX, downloaded as MAX.test.labels.tsv.gz ). - -6. Run `prep_metadata_labels.py`. - - -#### Instructions to run on Codalab bundle -7. \ No newline at end of file diff --git a/dataset_preprocessing/encode-tfbs/prep_accessibility.py b/dataset_preprocessing/encode-tfbs/prep_accessibility.py deleted file mode 100644 index 65d66341..00000000 --- a/dataset_preprocessing/encode-tfbs/prep_accessibility.py +++ /dev/null @@ -1,180 +0,0 @@ -# Adapted from https://github.com/GuanLab/Leopard/blob/master/data/quantile_normalize_bigwig.py - -import argparse, time -import numpy as np -import pyBigWig - -# Human chromosomes in hg19 -chrom_sizes = {'chr1': 249250621, 'chr10': 135534747, 'chr11': 135006516, 'chr12': 133851895, 'chr13': 115169878, 'chr14': 107349540, 'chr15': 102531392, 'chr16': 90354753, 'chr17': 81195210, 'chr18': 78077248, 'chr19': 59128983, 'chr2': 243199373, 'chr20': 63025520, 'chr21': 48129895, 'chr22': 51304566, 'chr3': 198022430, 'chr4': 191154276, 'chr5': 180915260, 'chr6': 171115067, 'chr7': 159138663, 'chr8': 146364022, 'chr9': 141213431, 'chrX': 155270560} - - -def qn_sample_to_array( - input_celltypes, - input_chroms=None, - subsampling_ratio=1000, - data_pfx = '/users/abalsubr/wilds/examples/data/encode-tfbs_v1.0/' -): - itime = time.time() - if input_chroms is None: - input_chroms = chrom_sizes.keys() - qn_chrom_sizes = { k: chrom_sizes[k] for k in input_chroms } - # chromosome-specific subsampling seeds - chr_to_seed = {} - i = 0 - for the_chr in qn_chrom_sizes: - chr_to_seed[the_chr] = i - i += 1 - - # subsampling; multiple replicates are added - sample_len = np.ceil(np.array(list(qn_chrom_sizes.values()))/subsampling_ratio).astype(int) - sample = np.zeros(sum(sample_len)) - start = 0 - j = 0 - for the_chr in qn_chrom_sizes: - np.random.seed(chr_to_seed[the_chr]) - for ct in input_celltypes: - path = data_pfx + 'DNASE.{}.fc.signal.bigwig'.format(ct) - bw = pyBigWig.open(path) - signal = np.nan_to_num(np.array(bw.values(the_chr, 0, qn_chrom_sizes[the_chr]))) - index = np.random.randint(0, len(signal), sample_len[j]) - sample[start:(start+sample_len[j])] += (1.0/len(input_celltypes))*signal[index] - start += sample_len[j] - j += 1 - print(the_chr, ct, time.time() - itime) - - if np.any(np.isnan(sample)): - print('wtf! sample contains nan!') - sample.sort() - np.save(data_pfx + "qn.{}.npy".format('.'.join(input_celltypes)), sample) - - -# quantile normalization via numpy inter/extra-polation -def anchor(input_data, sample, ref): # input 1d array - sample.sort() - ref.sort() - # 0. create the mapping function - index = np.array(np.where(np.diff(sample) != 0)) + 1 - index = index.flatten() - x = np.concatenate((np.zeros(1), sample[index])) # domain - y = np.zeros(len(x)) # codomain - for i in np.arange(0,len(index)-1, 1): - start = index[i] - end = index[i+1] - y[i+1] = np.mean(ref[start:end]) - i += 1 - start = index[i] - end = len(ref) - y[i+1] = np.mean(ref[start:end]) - # 1. interpolate - output = np.interp(input_data, x, y) - # 2. extrapolate - degree = 1 # degree of the fitting polynomial - num = 10 # number of positions for extrapolate - f1 = np.poly1d(np.polyfit(sample[-num:],ref[-num:],degree)) -# f2=np.poly1d(np.polyfit(sample[:num],ref[:num],degree)) - output[input_data > sample[-1]] = f1(input_data[input_data > sample[-1]]) -# output[input_data20) minutes\n", - "print(\"Saving npz archive...\")\n", - "np.savez_compressed('codalab_archive/sequence', **kw_dict)\n", - "print(time.time() - itime)\n", - "\n", - "# # Save as npy arrays\n", - "# itime = time.time()\n", - "# for chrom in kw_dict:\n", - "# np.save('sequence/{}.npy'.format(chrom), kw_dict[chrom])\n", - "# print(chrom, time.time() - itime)\n", - "\n", - "npz_archive = np.load('codalab_archive/sequence.npz')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## DNase" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "liver 0.006468534469604492\n", - "chr1 8.260387659072876\n", - "chr1 13.276052474975586\n", - "chr10 17.844778299331665\n", - "chr10 25.784512758255005\n", - "chr11 30.30143165588379\n", - "chr11 33.256701707839966\n", - "chr12 37.791435956954956\n", - "chr12 40.85292291641235\n", - "chr13 44.619521141052246\n", - "chr13 47.792500495910645\n", - "chr14 51.4214243888855\n", - "chr14 53.6813702583313\n", - "chr15 56.946401834487915\n", - "chr15 59.10466551780701\n", - "chr16 61.939475774765015\n", - "chr16 63.999470472335815\n", - "chr17 66.63648653030396\n", - "chr17 68.4126443862915\n", - "chr18 71.05454993247986\n", - "chr18 72.90085673332214\n", - "chr19 74.78594756126404\n", - "chr19 76.80954170227051\n", - "chr2 85.25815343856812\n", - "chr2 95.36479425430298\n", - "chr20 97.74516272544861\n", - "chr20 99.27151441574097\n", - "chr21 100.82207584381104\n", - "chr21 103.02815318107605\n", - "chr22 104.63926863670349\n", - "chr22 106.02127361297607\n", - "chr3 112.71910071372986\n", - "chr3 117.30491018295288\n", - "chr4 123.77405095100403\n", - "chr4 128.67069339752197\n", - "chr5 134.89299392700195\n", - "chr5 138.83413815498352\n", - "chr6 144.83386087417603\n", - "chr6 149.115407705307\n", - "chr7 154.4929392337799\n", - "chr7 157.8094253540039\n", - "chr8 162.8749077320099\n", - "chr8 165.9331293106079\n", - "chr9 170.5435709953308\n", - "chr9 173.46287417411804\n", - "chrX 178.5410988330841\n", - "chrX 185.49569463729858\n", - "chrY 187.14469981193542\n", - "chrY 189.6306025981903\n", - "MCF-7 0.01819300651550293\n", - "chr1 8.266149282455444\n", - "chr1 13.86928129196167\n", - "chr10 18.216674327850342\n", - "chr10 20.975315809249878\n", - "chr11 25.302175998687744\n", - "chr11 34.40013885498047\n", - "chr12 38.70525503158569\n", - "chr12 41.59175777435303\n", - "chr13 45.130286693573\n", - "chr13 47.67305374145508\n", - "chr14 51.26033353805542\n", - "chr14 53.59153509140015\n", - "chr15 56.858047008514404\n", - "chr15 59.08759665489197\n", - "chr16 62.03992414474487\n", - "chr16 63.99170207977295\n", - "chr17 67.05595779418945\n", - "chr17 69.3644654750824\n", - "chr18 71.78018283843994\n", - "chr18 73.58044695854187\n", - "chr19 75.70175457000732\n", - "chr19 79.72573828697205\n", - "chr2 87.675612449646\n", - "chr2 92.91672372817993\n", - "chr20 95.51653027534485\n", - "chr20 96.88600373268127\n", - "chr21 98.43806076049805\n", - "chr21 103.25369572639465\n", - "chr22 104.84882092475891\n", - "chr22 106.21143817901611\n", - "chr3 112.67947244644165\n", - "chr3 116.70610451698303\n", - "chr4 122.56520342826843\n", - "chr4 126.52856135368347\n", - "chr5 132.38469552993774\n", - "chr5 136.28370690345764\n", - "chr6 141.5743978023529\n", - "chr6 145.10061717033386\n", - "chr7 150.44007444381714\n", - "chr7 155.55760312080383\n", - "chr8 160.3683557510376\n", - "chr8 163.43416213989258\n", - "chr9 167.90313267707825\n", - "chr9 172.0667405128479\n", - "chrX 176.69336795806885\n", - "chrX 181.83150935173035\n", - "K562 0.007167339324951172\n", - "chr1 8.471662998199463\n", - "chr1 13.464861631393433\n", - "chr10 17.858335494995117\n", - "chr10 20.700791835784912\n", - "chr11 25.168848276138306\n", - "chr11 28.01260733604431\n", - "chr12 32.38129758834839\n", - "chr12 35.250038385391235\n", - "chr13 38.72063398361206\n", - "chr13 43.30442762374878\n", - "chr14 46.55065989494324\n", - "chr14 51.87103271484375\n", - "chr15 55.08980083465576\n", - "chr15 57.35198903083801\n", - "chr16 60.444990396499634\n", - "chr16 62.56146717071533\n", - "chr17 65.33607196807861\n", - "chr17 75.77480912208557\n", - "chr18 78.25007915496826\n", - "chr18 82.4424319267273\n", - "chr19 84.73718905448914\n", - "chr19 86.0900673866272\n", - "chr2 93.6916708946228\n", - "chr2 98.61803960800171\n", - "chr20 100.70567536354065\n", - "chr20 102.18551921844482\n", - "chr21 103.75095820426941\n", - "chr21 104.96330642700195\n", - "chr22 106.666348695755\n", - "chr22 108.20869731903076\n", - "chr3 114.6058874130249\n", - "chr3 123.16646194458008\n", - "chr4 129.07538533210754\n", - "chr4 135.95439338684082\n", - "chr5 141.63543701171875\n", - "chr5 148.8255476951599\n", - "chr6 154.68585968017578\n", - "chr6 160.3087387084961\n", - "chr7 165.7410364151001\n", - "chr7 169.09255123138428\n", - "chr8 173.68864274024963\n", - "chr8 176.73100185394287\n", - "chr9 181.10383462905884\n", - "chr9 184.0267071723938\n", - "chrX 188.59823846817017\n", - "chrX 191.7538366317749\n" - ] - } - ], - "source": [ - "### import pyBigWig\n", - "import glob\n", - "\n", - "dnases = {}\n", - "celltypes = ['A549', 'GM12878', 'H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", - "\n", - "for ctype in celltypes:#glob.glob('dnase_bigwigs/*'):\n", - " itime = time.time()\n", - " # ctype = pth.split('/')[1].split('.')[1]\n", - " if ctype not in ['liver', 'MCF-7', 'K562']:\n", - " continue\n", - " bw = pyBigWig.open(\"dnase_bigwigs/DNASE.{}.fc.signal.bigwig\".format(ctype))\n", - " chromsizes = bw.chroms()\n", - " print(ctype, time.time() - itime)\n", - " dn_dict = {}\n", - " for chrom in chromsizes: #chr_IDs:\n", - " x = bw.values(chrom, 0, chromsizes[chrom], numpy=True)\n", - " dn_dict[chrom] = np.nan_to_num(x).astype(np.float16) # half-precision makes things significantly smaller (less time to load)\n", - " print(chrom, time.time() - itime)\n", - " \n", - " np.save('dnase/{}/{}.npy'.format(ctype, chrom), dn_dict[chrom])\n", - " print(chrom, time.time() - itime)\n", - " dnases[ctype] = dn_dict\n", - "\n", - "for ctype in dnases:\n", - " itime = time.time()\n", - " print(ctype)\n", - " dn_dict = dnases[ctype]\n", - " \n", - " # Save as npz archive\n", - " np.savez_compressed('codalab_archive/{}_dnase'.format(ctype), **dn_dict)\n", - " print(time.time() - itime)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.5" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py b/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py deleted file mode 100644 index 00a8b5c3..00000000 --- a/dataset_preprocessing/encode-tfbs/prep_metadata_labels.py +++ /dev/null @@ -1,146 +0,0 @@ -import os, csv -import scipy, numpy as np, pandas as pd, time -from scipy import sparse -import pyBigWig - -# Human chromosome names -chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', - 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', - 'chr20', 'chr21', 'chr22', 'chrX'] -chrom_sizes = {'chr1': 249250621, 'chr10': 135534747, 'chr11': 135006516, 'chr12': 133851895, 'chr13': 115169878, 'chr14': 107349540, 'chr15': 102531392, 'chr16': 90354753, 'chr17': 81195210, 'chr18': 78077248, 'chr19': 59128983, 'chr2': 243199373, 'chr20': 63025520, 'chr21': 48129895, 'chr22': 51304566, 'chr3': 198022430, 'chr4': 191154276, 'chr5': 180915260, 'chr6': 171115067, 'chr7': 159138663, 'chr8': 146364022, 'chr9': 141213431, 'chrX': 155270560} - -_data_dir = '../../examples/data/encode-tfbs_v1.0/' - - -def write_label_bigwigs( - celltypes, - train_suffix='train.labels.tsv.gz', - val_suffix='val.labels.tsv.gz', - tf_name='MAX' -): - itime = time.time() - - # Read in metadata dataframe from training+validation data - train_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.{}'.format(tf_name, train_suffix)), sep='\t') - val_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.{}'.format(tf_name, val_suffix)), sep='\t') - training_df = train_regions_labeled - val_df = val_regions_labeled - all_df = pd.concat([training_df, val_df]) - - # Get the y values, and remove negative labels by default. - pd_list = [] - for ct in celltypes: - tc_chr = all_df[['chr', 'start', 'stop', ct]] - tc_chr.columns = ['chr', 'start', 'stop', 'y'] - tc_chr = tc_chr[tc_chr['y'] != 'U'] - tc_chr['y'] = tc_chr['y'].replace({'U': 0, 'B': 1, 'A': 0.5}).values - - tc_chr.insert(len(tc_chr.columns), 'celltype', ct) - pd_list.append(tc_chr) - print(ct, time.time() - itime) - _metadata_df = pd.concat(pd_list) - - print(time.time() - itime) - _unsorted_dir = _data_dir + 'labels/{}/{}_posamb.bed'.format( - tf_name, tf_name) - _sorted_dir = _unsorted_dir.replace( - '{}_posamb'.format(tf_name), - '{}_posamb.sorted'.format(tf_name) - ) - _metadata_df.to_csv( - _unsorted_dir, sep='\t', header=False, index=False - ) - print(time.time() - itime) - - # Sort bigwigs (as bed files) in order to convert to bigwig. - os.system('sort -k1,1 -k2,2n {} > {}'.format(_unsorted_dir, _sorted_dir)) - mdf_posamb = pd.read_csv( - _sorted_dir, - sep='\t', header=None, index_col=None, names=['chr', 'start', 'stop', 'y', 'celltype'] - ) - - # Write the binned labels to bigwig files - genome-wide labels - chromsizes_list = [(k, v) for k, v in chrom_sizes.items()] - for ct in celltypes: - ct_labels_bw_path = _data_dir + "labels/{}/{}_{}.bigwig".format( - tf_name, tf_name, ct) - df = mdf_posamb[mdf_posamb['celltype'] == ct] - bw = pyBigWig.open(ct_labels_bw_path, "w") - bw.addHeader(chromsizes_list) - bw.addEntries(list(df['chr']), list(df['start']), ends=list(df['start']+50), values=list(df['y'])) - print(ct, time.time() - itime) - bw.close() - - -def write_metadata_products( - celltypes, - bed_df_filename='metadata_df.bed', - y_arr_filename='metadata_y.npy', - stride=6400, - tf_name='MAX', - posamb_only=False -): - itime = time.time() - celltype_mdta = [] - celltype_labels = [] - if posamb_only: - mdf_posamb = pd.read_csv( - _data_dir + 'labels/{}/{}_posamb.sorted.bed'.format(tf_name, tf_name), - sep='\t', header=None, index_col=None, names=['chr', 'start', 'stop', 'y', 'celltype'] - ) - # Retrieve only the windows containing positively/ambiguously labeled bins (if posamb_only==True), or all windows (if posamb_only==False). - for ct in celltypes: - ct_labels_bw_path = _data_dir + "labels/{}/{}_{}.bigwig".format(tf_name, tf_name, ct) - df_construction = [] - mdta_labels = [] - bw = pyBigWig.open(ct_labels_bw_path) - if posamb_only: # Retrieve only the windows containing positively/ambiguously labeled bins - df = mdf_posamb[mdf_posamb['celltype'] == ct] - df['window_start'] = stride*(df['start'] // stride) - uniq_windows = np.unique(["{}:{}".format(x[0], x[1]) for x in zip(df['chr'], df['window_start'])]) - for u in uniq_windows: - u_chr = u.split(':')[0] - u_start = int(u.split(':')[1]) - u_end = u_start + stride - x = np.nan_to_num(bw.values(u_chr, u_start, u_end, numpy=True)) - df_construction.append((u_chr, u_start, u_end)) - mdta_labels.append(x[np.arange(0, len(x), 50)]) - else: # Retrieve all windows genome-wide - for chrID in bw.chroms(): - chromsize = bw.chroms()[chrID] - # Iterate over windows - for startc in np.arange(int(stride/2), chromsize-(2*stride), stride): - u_end = startc + stride - if u_end > chromsize: - break - x = np.nan_to_num(bw.values(chrID, startc, u_end, numpy=True)) - df_construction.append((chrID, startc, u_end)) - mdta_labels.append(x[np.arange(0, len(x), 50)]) - print(ct, chrID, time.time() - itime) - celltype_mdta_df = pd.DataFrame(df_construction, columns=['chr', 'start', 'stop']) - celltype_mdta_df.insert(len(celltype_mdta_df.columns), 'celltype', ct) - celltype_mdta.append(celltype_mdta_df) - celltype_labels.append(np.stack(mdta_labels)) - print(ct, time.time() - itime) - bw.close() - print(time.time() - itime) - - all_metadata_df = pd.concat(celltype_mdta) - all_metadata_df.to_csv( - _data_dir + 'labels/{}/{}'.format(tf_name, bed_df_filename), - sep='\t', header=False, index=False - ) - np.save(_data_dir + 'labels/{}/{}'.format(tf_name, y_arr_filename), np.vstack(celltype_labels)) - - -if __name__ == '__main__': - tf_name = 'JUND' - tfs_to_celltypes = { - 'MAX': ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562', 'A549', 'GM12878', 'liver'], - 'REST': ['H1-hESC', 'HeLa-S3', 'HepG2', 'MCF-7', 'Panc1', 'liver'], - 'JUND': ['HCT116', 'HeLa-S3', 'HepG2', 'K562', 'MCF-7', 'liver'] - } - all_celltypes = tfs_to_celltypes[tf_name] - write_label_bigwigs([x for x in all_celltypes if x != 'liver'], tf_name=tf_name) - write_label_bigwigs(['liver'], train_suffix='train_wc.labels.tsv.gz', val_suffix='test.labels.tsv.gz', tf_name=tf_name) - write_metadata_products(all_celltypes, tf_name=tf_name) diff --git a/dataset_preprocessing/encode-tfbs/prep_sequence.py b/dataset_preprocessing/encode-tfbs/prep_sequence.py deleted file mode 100644 index b80be0da..00000000 --- a/dataset_preprocessing/encode-tfbs/prep_sequence.py +++ /dev/null @@ -1,131 +0,0 @@ -import argparse, time -import numpy as np - -from tqdm import tqdm - -# Sequence preprocessing. Code adapted from Jacob Schreiber. - -# Human chromosome names -chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', - 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', - 'chr20', 'chr21', 'chr22', 'chrX'] - -def one_hot_encode(sequence, ignore='N', alphabet=None, dtype='int8', verbose=False, **kwargs): - """ - Converts a string or list of characters into a one-hot encoding. - This function will take in either a string or a list and convert it into a one-hot encoding. If the input is a string, each character is assumed to be a different symbol, e.g. 'ACGT' is assumed to be a sequence of four characters. If the input is a list, the elements can be any size. - Although this function will be used here primarily to convert nucleotide sequences into one-hot encoding with an alphabet of size 4, in principle this function can be used for any types of sequences. - - Parameters - ---------- - sequence : str or list - The sequence to convert to a one-hot encoding. - ignore : str, optional - A character to indicate setting nothing to 1 for that row, keeping the encoding entirely 0's for that row. In the context of genomics, this is the N character. Default is 'N'. - alphabet : set or tuple or list, optional - A pre-defined alphabet. If None is passed in, the alphabet will be determined from the sequence, but this may be time consuming for large sequences. Default is None. - dtype : str or numpy.dtype, optional - The data type of the returned encoding. Default is int8. - verbose : bool or str, optional - Whether to display a progress bar. If a string is passed in, use as the name of the progressbar. Default is False. - kwargs : arguments - Arguments to be passed into tqdm. Default is None. - - Returns - ------- - ohe : numpy.ndarray - A binary matrix of shape (alphabet_size, sequence_length) where alphabet_size is the number of unique elements in the sequence and sequence_length is the length of the input sequence. - """ - - name = None if verbose in (True, False) else verbose - d = verbose is False - - if isinstance(sequence, str): - sequence = list(sequence) - - alphabet = alphabet or np.unique(sequence) - alphabet = [char for char in alphabet if char != ignore] - alphabet_lookup = {char: i for i, char in enumerate(alphabet)} - - ohe = np.zeros((len(sequence), len(alphabet)), dtype=dtype) - for i, char in tqdm(enumerate(sequence), disable=d, desc=name, **kwargs): - if char != ignore: - idx = alphabet_lookup[char] - ohe[i, idx] = 1 - - return ohe - - -def read_fasta(filename, include_chroms=None, exclude_chroms=None, ignore='N', alphabet=['A', 'C', 'G', 'T', 'N'], verbose=True): - """ - Read in a FASTA file and output a dictionary of sequences. - This function will take in the path to a FASTA-formatted file and output a string containing the sequence for each chromosome. Optionally, the user can specify a set of chromosomes to include or exclude from the returned dictionary. - - Parameters - ---------- - filename : str - The path to the FASTA-formatted file to open. - include_chroms : set or tuple or list, optional - The exact names of chromosomes in the FASTA file to include, excluding all others. If None, include all chromosomes (except those specified by exclude_chroms). Default is None. - exclude_chroms : set or tuple or list, optional - The exact names of chromosomes in the FASTA file to exclude, including all others. If None, include all chromosomes (or the set specified by include_chroms). Default is None. - ignore : str, optional - A character to indicate setting nothing to 1 for that row, keeping the encoding entirely 0's for that row. In the context of genomics, this is the N character. Default is 'N'. - alphabet : set or tuple or list, optional - A pre-defined alphabet. If None is passed in, the alphabet will be determined from the sequence, but this may be time consuming for large sequences. Must include the ignore character. Default is ['A', 'C', 'G', 'T', 'N']. - verbose : bool or str, optional - Whether to display a progress bar. If a string is passed in, use as the name of the progressbar. Default is False. - - Returns - ------- - chroms : dict - A dictionary of strings where the keys are the names of the chromosomes (exact strings from the header lines in the FASTA file) and the values are the strings encoded there. - """ - - sequences = {} - name, sequence = None, None - skip_chrom = False - - with open(filename, "r") as infile: - for line in tqdm(infile, disable=not verbose): - if line.startswith(">"): - if name is not None and skip_chrom is False: - sequences[name] = ''.join(sequence) - sequence = [] - name = line[1:].strip("\n") - if include_chroms is not None and name not in include_chroms: - skip_chrom = True - elif exclude_chroms is not None and name in exclude_chroms: - skip_chrom = True - else: - skip_chrom = False - else: - if skip_chrom == False: - sequence.append(line.rstrip("\n").upper()) - return sequences - - -def generate_sequence_archive(seq_path='sequence/hg19.genome.fa', output_dir): - fasta_contents = read_fasta() - kw_dict = {} - itime = time.time() - for chrom in chr_IDs: - seqstr = fasta_contents[chrom] - kw_dict[chrom] = one_hot_encode(seqstr, alphabet=['A', 'C', 'G', 'T', 'N']) - print(chrom, time.time() - itime) - - # Save as npz archive; can take several (>20) minutes - print("Saving npz archive...") - np.savez_compressed('{}/sequence'.format(output_root), **kw_dict) - print(time.time() - itime) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--seq_path', required=True) - parser.add_argument('--output_dir', required=True) - args = parser.parse_args() - - generate_sequence_archive( - seq_path=args.seq_path, - output_dir=args.output_dir) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 052d9629..47e502a8 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -105,7 +105,7 @@ }, 'process_outputs_function': 'multiclass_logits_to_pred', }, - 'encode-tfbs': { + 'encode': { 'split_scheme': 'official', 'model': 'unet-seq', 'model_kwargs': {'n_channels_in': 5}, diff --git a/wilds/__init__.py b/wilds/__init__.py index 7a9ed877..a5290478 100644 --- a/wilds/__init__.py +++ b/wilds/__init__.py @@ -9,7 +9,7 @@ 'ogb-molpcba', 'poverty', 'fmow', - 'py150', + 'py150', ] additional_datasets = [ @@ -18,7 +18,7 @@ 'yelp', 'bdd100k', 'sqf', - 'encode-tfbs' + 'encode' ] supported_datasets = benchmark_datasets + additional_datasets diff --git a/wilds/datasets/encodetfbs_dataset.py b/wilds/datasets/encodetfbs_dataset.py deleted file mode 100644 index 9fe44e98..00000000 --- a/wilds/datasets/encodetfbs_dataset.py +++ /dev/null @@ -1,430 +0,0 @@ -import os, time -import torch -import pandas as pd -import numpy as np -import pyBigWig -from wilds.datasets.wilds_dataset import WILDSDataset -from wilds.common.utils import subsample_idxs -from wilds.common.grouper import CombinatorialGrouper -from wilds.common.metrics.all_metrics import MultiTaskAveragePrecision - -# Human chromosomes in hg19 -chrom_sizes = {'chr1': 249250621, 'chr10': 135534747, 'chr11': 135006516, 'chr12': 133851895, 'chr13': 115169878, 'chr14': 107349540, 'chr15': 102531392, 'chr16': 90354753, 'chr17': 81195210, 'chr18': 78077248, 'chr19': 59128983, 'chr2': 243199373, 'chr20': 63025520, 'chr21': 48129895, 'chr22': 51304566, 'chr3': 198022430, 'chr4': 191154276, 'chr5': 180915260, 'chr6': 171115067, 'chr7': 159138663, 'chr8': 146364022, 'chr9': 141213431, 'chrX': 155270560} - -# quantile normalization via numpy inter/extra-polation -def anchor(input_data, sample, ref): # input 1d array - sample.sort() - ref.sort() - # 0. create the mapping function - index = np.array(np.where(np.diff(sample) != 0)) + 1 - index = index.flatten() - x = np.concatenate((np.zeros(1), sample[index])) # domain - y = np.zeros(len(x)) # codomain - for i in np.arange(0,len(index)-1, 1): - start = index[i] - end = index[i+1] - y[i+1] = np.mean(ref[start:end]) - i += 1 - start = index[i] - end = len(ref) - y[i+1] = np.mean(ref[start:end]) - # 1. interpolate - output = np.interp(input_data, x, y) - # 2. extrapolate - degree = 1 # degree of the fitting polynomial - num = 10 # number of positions for extrapolate - f1 = np.poly1d(np.polyfit(sample[-num:],ref[-num:],degree)) -# f2=np.poly1d(np.polyfit(sample[:num],ref[:num],degree)) - output[input_data > sample[-1]] = f1(input_data[input_data > sample[-1]]) -# output[input_data.' - self._transcription_factor = 'MAX' - if 'tf.' in split_scheme: - tkns = split_scheme.split('.') - self._transcription_factor = tkns[1] - split_scheme = '.'.join(tkns[2:]) - self._split_scheme = split_scheme - - train_celltypes = official_train_cts[self._transcription_factor] - val_celltype = official_val_cts[self._transcription_factor] - test_celltype = official_test_cts[self._transcription_factor] - - if self._split_scheme == 'official': - splits = { - 'train': { - 'chroms': train_chroms, - 'celltypes': train_celltypes - }, - 'id_val': { - 'chroms': val_chroms, - 'celltypes': train_celltypes - }, - 'val': { - 'chroms': val_chroms, - 'celltypes': val_celltype - }, - 'test': { - 'chroms': test_chroms, - 'celltypes': test_celltype - }, - 'id_test': { - 'chroms': test_chroms, - 'celltypes': train_celltypes - } - } - self._split_dict = { - 'train': 0, - 'val': 1, - 'test': 2, - 'id_val': 3, - 'id_test': 4 - } - self._split_names = { - 'train': 'Train', - 'val': 'Validation (OOD)', - 'test': 'Test', - 'id_val': 'Validation (ID)', - 'id_test': 'Test (ID)', - } - elif self._split_scheme == 'in-dist': - splits = { - 'train': { - 'chroms': train_chroms, - 'celltypes': test_celltype, - }, - 'val': { - 'chroms': val_chroms, - 'celltypes': test_celltype - }, - 'test': { - 'chroms': test_chroms, - 'celltypes': test_celltype - }, - } - self._split_dict = { - 'train': 0, - 'val': 1, - 'test': 2, - } - self._split_names = { - 'train': 'Train', - 'val': 'Validation (OOD)', - 'test': 'Test', - } - elif 'id-' in self._split_scheme: - test_celltype = [ self._split_scheme.split('id-')[1] ] - splits = { - 'train': { - 'chroms': train_chroms, - 'celltypes': test_celltype, - }, - 'val': { - 'chroms': val_chroms, - 'celltypes': test_celltype - }, - 'test': { - 'chroms': test_chroms, - 'celltypes': test_celltype - }, - } - self._split_dict = { - 'train': 0, - 'val': 1, - 'test': 2, - } - self._split_names = { - 'train': 'Train', - 'val': 'Validation (OOD)', - 'test': 'Test', - } - - # Add new split scheme specifying custom test and val celltypes in the format val..test., e.g. self._split_scheme == 'official' is equivalent to self._split_scheme == 'val.HepG2.test.liver' - elif '.' in self._split_scheme: - all_celltypes = train_celltypes + val_celltype + test_celltype - in_val_ct = self._split_scheme.split('.')[1] - in_test_ct = self._split_scheme.split('.')[3] - train_celltypes = [ct for ct in all_celltypes if ((ct != in_val_ct) and (ct != in_test_ct))] - val_celltype = [in_val_ct] - test_celltype = [in_test_ct] - splits = { - 'train': { - 'chroms': train_chroms, - 'celltypes': train_celltypes - }, - 'id_val': { - 'chroms': val_chroms, - 'celltypes': train_celltypes - }, - 'val': { - 'chroms': val_chroms, - 'celltypes': val_celltype - }, - 'test': { - 'chroms': test_chroms, - 'celltypes': test_celltype - }, - 'id_test': { - 'chroms': test_chroms, - 'celltypes': train_celltypes - } - } - self._split_dict = { - 'train': 0, - 'val': 1, - 'test': 2, - 'id_val': 3, - 'id_test': 4 - } - self._split_names = { - 'train': 'Train', - 'val': 'Validation (OOD)', - 'test': 'Test', - 'id_val': 'Validation (ID)', - 'id_test': 'Test (ID)', - } - else: - raise ValueError(f'Split scheme {self._split_scheme} not recognized') - - # Read in metadata and labels - self._metadata_df = pd.read_csv( - self._data_dir + '/labels/{}/metadata_df.bed'.format(self._transcription_factor), - sep='\t', header=None, - index_col=None, names=['chr', 'start', 'stop', 'celltype'] - ) - self._y_array = torch.tensor(np.load( - self._data_dir + '/labels/{}/metadata_y.npy'.format(self._transcription_factor))) - - # ~10% of the dataset has ambiguous labels - # i.e., we can't tell if there is a binding event or not. - # This typically happens at the flanking regions of peaks. - # For our purposes, we will ignore these ambiguous labels during training and eval. - self.y_array[self.y_array == 0.5] = float('nan') - - self._split_array = -1 * np.ones(self._metadata_df.shape[0]).astype(int) - for split, d in splits.items(): - chrom_mask = np.isin(self._metadata_df['chr'], d['chroms']) - celltype_mask = np.isin(self._metadata_df['celltype'], d['celltypes']) - self._split_array[chrom_mask & celltype_mask] = self._split_dict[split] - - keep_mask = (self._split_array != -1) - - # Remove all-zero sequences from training. - train_mask = (self._split_array == self._split_dict['train']) - allzeroes_mask = (self._y_array.sum(axis=1) == 0).numpy() - keep_mask = keep_mask & ~(train_mask & allzeroes_mask) - - # Subsample the testing and validation indices, to speed up evaluation. - # For the OOD splits (val and test), we subsample by a factor of 3 - # For the id_val and id_test splits, we subsample by a factor of 3*(# of training celltypes) - for subsample_seed, (split, subsample_factor) in enumerate([ - ('val', 3), - ('test', 3), - ('id_val', 3*len(splits['train']['celltypes'])), - ('id_test', 3*len(splits['train']['celltypes']))]): - if split not in self._split_dict: continue - split_mask = (self._split_array == self._split_dict[split]) - split_idxs = np.arange(len(self._split_array))[split_mask] - idxs_to_remove = subsample_idxs( - split_idxs, - num=len(split_idxs) // subsample_factor, - seed=subsample_seed, - take_rest=True) - keep_mask[idxs_to_remove] = False - - self._metadata_df = self._metadata_df[keep_mask] - self._split_array = self._split_array[keep_mask] - self._y_array = self._y_array[keep_mask] - - self._all_chroms = sorted(list({chrom for _, d in splits.items() for chrom in d['chroms']})) - self._all_celltypes = sorted(list({chrom for _, d in splits.items() for chrom in d['celltypes']})) - - # Load sequence into memory - sequence_filename = os.path.join(self._data_dir, 'sequence.npz') - seq_arr = np.load(sequence_filename) - self._seq_bp = {} - for chrom in self._all_chroms: - self._seq_bp[chrom] = seq_arr[chrom] - print(chrom, time.time() - itime) - del seq_arr - - # Set up file handles for DNase features, writing normalized DNase tracks along the way if they aren't already written. - self._dnase_allcelltypes = {} - for ct in self._all_celltypes: - orig_dnase_bw_path = os.path.join(self._data_dir, 'DNASE.{}.fc.signal.bigwig'.format(ct)) - dnase_bw_path = os.path.join(self._data_dir, 'DNase.{}.{}.{}.bigwig'.format(self._transcription_factor, ct, self._split_scheme)) - if not os.path.exists(dnase_bw_path): - ref_celltypes = splits['train']['celltypes'] - dnase_normalize(ct, ref_celltypes, out_fname=dnase_bw_path, data_pfx=self._data_dir) - self._dnase_allcelltypes[ct] = pyBigWig.open(dnase_bw_path) - - # Load subsampled DNase arrays for normalization purposes - self._dnase_qnorm_arrays = {} - for ct in self._all_celltypes: - qnorm_arr_path = os.path.join(self._data_dir, 'qn.{}.npy'.format(ct)) - self._dnase_qnorm_arrays[ct] = np.load(qnorm_arr_path) - self._norm_ref_distr = np.zeros(len(self._dnase_qnorm_arrays[ct])) - test_cts = splits['test']['celltypes'] - num_to_avg = len(self._all_celltypes) - len(test_cts) - for ct in self._all_celltypes: - if ct not in test_cts: - self._norm_ref_distr += (1.0/num_to_avg)*self._dnase_qnorm_arrays[ct] - - # Set up metadata fields, map, array - self._metadata_fields = ['chr', 'celltype'] - self._metadata_map = {} - self._metadata_map['chr'] = self._all_chroms - self._metadata_map['celltype'] = self._all_celltypes - chr_ints = self._metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['chr'])] )).values - celltype_ints = self._metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['celltype'])] )).values - self._metadata_array = torch.stack( - (torch.LongTensor(chr_ints), - torch.LongTensor(celltype_ints) - ), - dim=1) - - self._eval_grouper = CombinatorialGrouper( - dataset=self, - groupby_fields=['celltype']) - - self._metric = MultiTaskAveragePrecision() - - super().__init__(root_dir, download, split_scheme) - - def get_input(self, idx, window_size=12800): - """ - Returns x for a given idx in metadata_array, which has been filtered to only take windows with the desired stride. - Computes this from: - (1) sequence features in self._seq_bp - (2) DNase bigwig file handles in self._dnase_allcelltypes - (3) Metadata for the index (location along the genome with 6400bp window width) - (4) Window_size, the length of sequence returned (centered on the 6400bp region in (3)) - """ - this_metadata = self._metadata_df.iloc[idx, :] - chrom = this_metadata['chr'] - interval_start = this_metadata['start'] - int(window_size/4) - interval_end = interval_start + window_size - seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end] - dnase_bw = self._dnase_allcelltypes[this_metadata['celltype']] - dnase_this = np.nan_to_num(dnase_bw.values(chrom, interval_start, interval_end, numpy=True)) - return torch.tensor(np.column_stack( - [seq_this, - dnase_this] - ).T) - - def eval(self, y_pred, y_true, metadata): - return self.standard_group_eval( - self._metric, - self._eval_grouper, - y_pred, y_true, metadata) diff --git a/wilds/get_dataset.py b/wilds/get_dataset.py index a67ba538..14a8d9c7 100644 --- a/wilds/get_dataset.py +++ b/wilds/get_dataset.py @@ -78,6 +78,6 @@ def get_dataset(dataset, version=None, **dataset_kwargs): from wilds.datasets.sqf_dataset import SQFDataset return SQFDataset(version=version, **dataset_kwargs) - elif dataset == 'encode-tfbs': - from wilds.datasets.encodetfbs_dataset import EncodeTFBSDataset - return EncodeTFBSDataset(version=version, **dataset_kwargs) + elif dataset == 'encode': + from wilds.datasets.encode_dataset import EncodeDataset + return EncodeDataset(version=version, **dataset_kwargs) From 69ba0f2545a41afd24772c836e4a5db0ee4a2471 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Tue, 6 Jul 2021 22:15:53 -0700 Subject: [PATCH 213/244] encode files with new names --- dataset_preprocessing/encode/README.md | 26 ++ .../encode/prep_accessibility.py | 180 ++++++++ .../encode/prep_datasets.ipynb | 279 ++++++++++++ .../encode/prep_metadata_labels.py | 146 ++++++ dataset_preprocessing/encode/prep_sequence.py | 131 ++++++ wilds/datasets/encode_dataset.py | 425 ++++++++++++++++++ 6 files changed, 1187 insertions(+) create mode 100644 dataset_preprocessing/encode/README.md create mode 100644 dataset_preprocessing/encode/prep_accessibility.py create mode 100644 dataset_preprocessing/encode/prep_datasets.ipynb create mode 100644 dataset_preprocessing/encode/prep_metadata_labels.py create mode 100644 dataset_preprocessing/encode/prep_sequence.py create mode 100644 wilds/datasets/encode_dataset.py diff --git a/dataset_preprocessing/encode/README.md b/dataset_preprocessing/encode/README.md new file mode 100644 index 00000000..1d045e07 --- /dev/null +++ b/dataset_preprocessing/encode/README.md @@ -0,0 +1,26 @@ +## ENCODE feature generation and preprocessing + +#### Requirements +- pyBigWig + +#### Instructions to create Codalab bundle + +1. Download the human genome sequence (hg19 assembly) in FASTA format from http://hgdownload.cse.ucsc.edu/goldenpath/hg19/bigZips/hg19.fa.gz and extract it into `SEQUENCE_PATH`. + +2. Run `python prep_sequence.py --seq_path SEQUENCE_PATH --output_dir OUTPUT_DIR` to write the fasta file found in `SEQUENCE_PATH` to a numpy array archive in `OUTPUT_DIR`. + +3. Download the DNase accessibility data. This consists of whole-genome DNase files in bigwig format from https://guanfiles.dcmb.med.umich.edu/Leopard/dnase_bigwig/. These are saved with filename `DNASE..fc.signal.bigwig`. + +4. Run `prep_accessibility.py`. + +5. Download the labels from the challenge into a label directory `labels/` created for this purpose: + - The training chromosome labels for the challenge's training cell types from https://www.synapse.org/#!Synapse:syn7413983 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn7415202 for the TF MAX, , downloaded as MAX.train.labels.tsv.gz ). + - The training chromosome labels for the challenge's evaluation cell type (liver) from https://www.synapse.org/#!Synapse:syn8077511 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn8077648 for the TF MAX, downloaded as MAX.train_wc.labels.tsv.gz ). + - The validation chromosome labels for the challenge's training cell types from https://www.synapse.org/#!Synapse:syn8441154 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn8442103 for the TF MAX, downloaded as MAX.val.labels.tsv.gz ). + - The validation chromosome labels for the challenge's evaluation cell type (liver) from https://www.synapse.org/#!Synapse:syn8442975 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn8443021 for the TF MAX, downloaded as MAX.test.labels.tsv.gz ). + +6. Run `prep_metadata_labels.py`. + + +#### Instructions to run on Codalab bundle +7. diff --git a/dataset_preprocessing/encode/prep_accessibility.py b/dataset_preprocessing/encode/prep_accessibility.py new file mode 100644 index 00000000..a3303317 --- /dev/null +++ b/dataset_preprocessing/encode/prep_accessibility.py @@ -0,0 +1,180 @@ +# Adapted from https://github.com/GuanLab/Leopard/blob/master/data/quantile_normalize_bigwig.py + +import argparse, time +import numpy as np +import pyBigWig + +# Human chromosomes in hg19 +chrom_sizes = {'chr1': 249250621, 'chr10': 135534747, 'chr11': 135006516, 'chr12': 133851895, 'chr13': 115169878, 'chr14': 107349540, 'chr15': 102531392, 'chr16': 90354753, 'chr17': 81195210, 'chr18': 78077248, 'chr19': 59128983, 'chr2': 243199373, 'chr20': 63025520, 'chr21': 48129895, 'chr22': 51304566, 'chr3': 198022430, 'chr4': 191154276, 'chr5': 180915260, 'chr6': 171115067, 'chr7': 159138663, 'chr8': 146364022, 'chr9': 141213431, 'chrX': 155270560} + + +def qn_sample_to_array( + input_celltypes, + input_chroms=None, + subsampling_ratio=1000, + data_pfx = '/users/abalsubr/wilds/examples/data/encode_v1.0/' +): + itime = time.time() + if input_chroms is None: + input_chroms = chrom_sizes.keys() + qn_chrom_sizes = { k: chrom_sizes[k] for k in input_chroms } + # chromosome-specific subsampling seeds + chr_to_seed = {} + i = 0 + for the_chr in qn_chrom_sizes: + chr_to_seed[the_chr] = i + i += 1 + + # subsampling; multiple replicates are added + sample_len = np.ceil(np.array(list(qn_chrom_sizes.values()))/subsampling_ratio).astype(int) + sample = np.zeros(sum(sample_len)) + start = 0 + j = 0 + for the_chr in qn_chrom_sizes: + np.random.seed(chr_to_seed[the_chr]) + for ct in input_celltypes: + path = data_pfx + 'DNASE.{}.fc.signal.bigwig'.format(ct) + bw = pyBigWig.open(path) + signal = np.nan_to_num(np.array(bw.values(the_chr, 0, qn_chrom_sizes[the_chr]))) + index = np.random.randint(0, len(signal), sample_len[j]) + sample[start:(start+sample_len[j])] += (1.0/len(input_celltypes))*signal[index] + start += sample_len[j] + j += 1 + print(the_chr, ct, time.time() - itime) + + if np.any(np.isnan(sample)): + print('wtf! sample contains nan!') + sample.sort() + np.save(data_pfx + "qn.{}.npy".format('.'.join(input_celltypes)), sample) + + +# quantile normalization via numpy inter/extra-polation +def anchor(input_data, sample, ref): # input 1d array + sample.sort() + ref.sort() + # 0. create the mapping function + index = np.array(np.where(np.diff(sample) != 0)) + 1 + index = index.flatten() + x = np.concatenate((np.zeros(1), sample[index])) # domain + y = np.zeros(len(x)) # codomain + for i in np.arange(0,len(index)-1, 1): + start = index[i] + end = index[i+1] + y[i+1] = np.mean(ref[start:end]) + i += 1 + start = index[i] + end = len(ref) + y[i+1] = np.mean(ref[start:end]) + # 1. interpolate + output = np.interp(input_data, x, y) + # 2. extrapolate + degree = 1 # degree of the fitting polynomial + num = 10 # number of positions for extrapolate + f1 = np.poly1d(np.polyfit(sample[-num:],ref[-num:],degree)) +# f2=np.poly1d(np.polyfit(sample[:num],ref[:num],degree)) + output[input_data > sample[-1]] = f1(input_data[input_data > sample[-1]]) +# output[input_data20) minutes\n", + "print(\"Saving npz archive...\")\n", + "np.savez_compressed('codalab_archive/sequence', **kw_dict)\n", + "print(time.time() - itime)\n", + "\n", + "# # Save as npy arrays\n", + "# itime = time.time()\n", + "# for chrom in kw_dict:\n", + "# np.save('sequence/{}.npy'.format(chrom), kw_dict[chrom])\n", + "# print(chrom, time.time() - itime)\n", + "\n", + "npz_archive = np.load('codalab_archive/sequence.npz')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## DNase" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "liver 0.006468534469604492\n", + "chr1 8.260387659072876\n", + "chr1 13.276052474975586\n", + "chr10 17.844778299331665\n", + "chr10 25.784512758255005\n", + "chr11 30.30143165588379\n", + "chr11 33.256701707839966\n", + "chr12 37.791435956954956\n", + "chr12 40.85292291641235\n", + "chr13 44.619521141052246\n", + "chr13 47.792500495910645\n", + "chr14 51.4214243888855\n", + "chr14 53.6813702583313\n", + "chr15 56.946401834487915\n", + "chr15 59.10466551780701\n", + "chr16 61.939475774765015\n", + "chr16 63.999470472335815\n", + "chr17 66.63648653030396\n", + "chr17 68.4126443862915\n", + "chr18 71.05454993247986\n", + "chr18 72.90085673332214\n", + "chr19 74.78594756126404\n", + "chr19 76.80954170227051\n", + "chr2 85.25815343856812\n", + "chr2 95.36479425430298\n", + "chr20 97.74516272544861\n", + "chr20 99.27151441574097\n", + "chr21 100.82207584381104\n", + "chr21 103.02815318107605\n", + "chr22 104.63926863670349\n", + "chr22 106.02127361297607\n", + "chr3 112.71910071372986\n", + "chr3 117.30491018295288\n", + "chr4 123.77405095100403\n", + "chr4 128.67069339752197\n", + "chr5 134.89299392700195\n", + "chr5 138.83413815498352\n", + "chr6 144.83386087417603\n", + "chr6 149.115407705307\n", + "chr7 154.4929392337799\n", + "chr7 157.8094253540039\n", + "chr8 162.8749077320099\n", + "chr8 165.9331293106079\n", + "chr9 170.5435709953308\n", + "chr9 173.46287417411804\n", + "chrX 178.5410988330841\n", + "chrX 185.49569463729858\n", + "chrY 187.14469981193542\n", + "chrY 189.6306025981903\n", + "MCF-7 0.01819300651550293\n", + "chr1 8.266149282455444\n", + "chr1 13.86928129196167\n", + "chr10 18.216674327850342\n", + "chr10 20.975315809249878\n", + "chr11 25.302175998687744\n", + "chr11 34.40013885498047\n", + "chr12 38.70525503158569\n", + "chr12 41.59175777435303\n", + "chr13 45.130286693573\n", + "chr13 47.67305374145508\n", + "chr14 51.26033353805542\n", + "chr14 53.59153509140015\n", + "chr15 56.858047008514404\n", + "chr15 59.08759665489197\n", + "chr16 62.03992414474487\n", + "chr16 63.99170207977295\n", + "chr17 67.05595779418945\n", + "chr17 69.3644654750824\n", + "chr18 71.78018283843994\n", + "chr18 73.58044695854187\n", + "chr19 75.70175457000732\n", + "chr19 79.72573828697205\n", + "chr2 87.675612449646\n", + "chr2 92.91672372817993\n", + "chr20 95.51653027534485\n", + "chr20 96.88600373268127\n", + "chr21 98.43806076049805\n", + "chr21 103.25369572639465\n", + "chr22 104.84882092475891\n", + "chr22 106.21143817901611\n", + "chr3 112.67947244644165\n", + "chr3 116.70610451698303\n", + "chr4 122.56520342826843\n", + "chr4 126.52856135368347\n", + "chr5 132.38469552993774\n", + "chr5 136.28370690345764\n", + "chr6 141.5743978023529\n", + "chr6 145.10061717033386\n", + "chr7 150.44007444381714\n", + "chr7 155.55760312080383\n", + "chr8 160.3683557510376\n", + "chr8 163.43416213989258\n", + "chr9 167.90313267707825\n", + "chr9 172.0667405128479\n", + "chrX 176.69336795806885\n", + "chrX 181.83150935173035\n", + "K562 0.007167339324951172\n", + "chr1 8.471662998199463\n", + "chr1 13.464861631393433\n", + "chr10 17.858335494995117\n", + "chr10 20.700791835784912\n", + "chr11 25.168848276138306\n", + "chr11 28.01260733604431\n", + "chr12 32.38129758834839\n", + "chr12 35.250038385391235\n", + "chr13 38.72063398361206\n", + "chr13 43.30442762374878\n", + "chr14 46.55065989494324\n", + "chr14 51.87103271484375\n", + "chr15 55.08980083465576\n", + "chr15 57.35198903083801\n", + "chr16 60.444990396499634\n", + "chr16 62.56146717071533\n", + "chr17 65.33607196807861\n", + "chr17 75.77480912208557\n", + "chr18 78.25007915496826\n", + "chr18 82.4424319267273\n", + "chr19 84.73718905448914\n", + "chr19 86.0900673866272\n", + "chr2 93.6916708946228\n", + "chr2 98.61803960800171\n", + "chr20 100.70567536354065\n", + "chr20 102.18551921844482\n", + "chr21 103.75095820426941\n", + "chr21 104.96330642700195\n", + "chr22 106.666348695755\n", + "chr22 108.20869731903076\n", + "chr3 114.6058874130249\n", + "chr3 123.16646194458008\n", + "chr4 129.07538533210754\n", + "chr4 135.95439338684082\n", + "chr5 141.63543701171875\n", + "chr5 148.8255476951599\n", + "chr6 154.68585968017578\n", + "chr6 160.3087387084961\n", + "chr7 165.7410364151001\n", + "chr7 169.09255123138428\n", + "chr8 173.68864274024963\n", + "chr8 176.73100185394287\n", + "chr9 181.10383462905884\n", + "chr9 184.0267071723938\n", + "chrX 188.59823846817017\n", + "chrX 191.7538366317749\n" + ] + } + ], + "source": [ + "### import pyBigWig\n", + "import glob\n", + "\n", + "dnases = {}\n", + "celltypes = ['A549', 'GM12878', 'H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", + "\n", + "for ctype in celltypes:#glob.glob('dnase_bigwigs/*'):\n", + " itime = time.time()\n", + " # ctype = pth.split('/')[1].split('.')[1]\n", + " if ctype not in ['liver', 'MCF-7', 'K562']:\n", + " continue\n", + " bw = pyBigWig.open(\"dnase_bigwigs/DNASE.{}.fc.signal.bigwig\".format(ctype))\n", + " chromsizes = bw.chroms()\n", + " print(ctype, time.time() - itime)\n", + " dn_dict = {}\n", + " for chrom in chromsizes: #chr_IDs:\n", + " x = bw.values(chrom, 0, chromsizes[chrom], numpy=True)\n", + " dn_dict[chrom] = np.nan_to_num(x).astype(np.float16) # half-precision makes things significantly smaller (less time to load)\n", + " print(chrom, time.time() - itime)\n", + " \n", + " np.save('dnase/{}/{}.npy'.format(ctype, chrom), dn_dict[chrom])\n", + " print(chrom, time.time() - itime)\n", + " dnases[ctype] = dn_dict\n", + "\n", + "for ctype in dnases:\n", + " itime = time.time()\n", + " print(ctype)\n", + " dn_dict = dnases[ctype]\n", + " \n", + " # Save as npz archive\n", + " np.savez_compressed('codalab_archive/{}_dnase'.format(ctype), **dn_dict)\n", + " print(time.time() - itime)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/dataset_preprocessing/encode/prep_metadata_labels.py b/dataset_preprocessing/encode/prep_metadata_labels.py new file mode 100644 index 00000000..8d1fc537 --- /dev/null +++ b/dataset_preprocessing/encode/prep_metadata_labels.py @@ -0,0 +1,146 @@ +import os, csv +import scipy, numpy as np, pandas as pd, time +from scipy import sparse +import pyBigWig + +# Human chromosome names +chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', + 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', + 'chr20', 'chr21', 'chr22', 'chrX'] +chrom_sizes = {'chr1': 249250621, 'chr10': 135534747, 'chr11': 135006516, 'chr12': 133851895, 'chr13': 115169878, 'chr14': 107349540, 'chr15': 102531392, 'chr16': 90354753, 'chr17': 81195210, 'chr18': 78077248, 'chr19': 59128983, 'chr2': 243199373, 'chr20': 63025520, 'chr21': 48129895, 'chr22': 51304566, 'chr3': 198022430, 'chr4': 191154276, 'chr5': 180915260, 'chr6': 171115067, 'chr7': 159138663, 'chr8': 146364022, 'chr9': 141213431, 'chrX': 155270560} + +_data_dir = '../../examples/data/encode_v1.0/' + + +def write_label_bigwigs( + celltypes, + train_suffix='train.labels.tsv.gz', + val_suffix='val.labels.tsv.gz', + tf_name='MAX' +): + itime = time.time() + + # Read in metadata dataframe from training+validation data + train_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.{}'.format(tf_name, train_suffix)), sep='\t') + val_regions_labeled = pd.read_csv(os.path.join(_data_dir, 'labels/{}.{}'.format(tf_name, val_suffix)), sep='\t') + training_df = train_regions_labeled + val_df = val_regions_labeled + all_df = pd.concat([training_df, val_df]) + + # Get the y values, and remove negative labels by default. + pd_list = [] + for ct in celltypes: + tc_chr = all_df[['chr', 'start', 'stop', ct]] + tc_chr.columns = ['chr', 'start', 'stop', 'y'] + tc_chr = tc_chr[tc_chr['y'] != 'U'] + tc_chr['y'] = tc_chr['y'].replace({'U': 0, 'B': 1, 'A': 0.5}).values + + tc_chr.insert(len(tc_chr.columns), 'celltype', ct) + pd_list.append(tc_chr) + print(ct, time.time() - itime) + _metadata_df = pd.concat(pd_list) + + print(time.time() - itime) + _unsorted_dir = _data_dir + 'labels/{}/{}_posamb.bed'.format( + tf_name, tf_name) + _sorted_dir = _unsorted_dir.replace( + '{}_posamb'.format(tf_name), + '{}_posamb.sorted'.format(tf_name) + ) + _metadata_df.to_csv( + _unsorted_dir, sep='\t', header=False, index=False + ) + print(time.time() - itime) + + # Sort bigwigs (as bed files) in order to convert to bigwig. + os.system('sort -k1,1 -k2,2n {} > {}'.format(_unsorted_dir, _sorted_dir)) + mdf_posamb = pd.read_csv( + _sorted_dir, + sep='\t', header=None, index_col=None, names=['chr', 'start', 'stop', 'y', 'celltype'] + ) + + # Write the binned labels to bigwig files - genome-wide labels + chromsizes_list = [(k, v) for k, v in chrom_sizes.items()] + for ct in celltypes: + ct_labels_bw_path = _data_dir + "labels/{}/{}_{}.bigwig".format( + tf_name, tf_name, ct) + df = mdf_posamb[mdf_posamb['celltype'] == ct] + bw = pyBigWig.open(ct_labels_bw_path, "w") + bw.addHeader(chromsizes_list) + bw.addEntries(list(df['chr']), list(df['start']), ends=list(df['start']+50), values=list(df['y'])) + print(ct, time.time() - itime) + bw.close() + + +def write_metadata_products( + celltypes, + bed_df_filename='metadata_df.bed', + y_arr_filename='metadata_y.npy', + stride=6400, + tf_name='MAX', + posamb_only=False +): + itime = time.time() + celltype_mdta = [] + celltype_labels = [] + if posamb_only: + mdf_posamb = pd.read_csv( + _data_dir + 'labels/{}/{}_posamb.sorted.bed'.format(tf_name, tf_name), + sep='\t', header=None, index_col=None, names=['chr', 'start', 'stop', 'y', 'celltype'] + ) + # Retrieve only the windows containing positively/ambiguously labeled bins (if posamb_only==True), or all windows (if posamb_only==False). + for ct in celltypes: + ct_labels_bw_path = _data_dir + "labels/{}/{}_{}.bigwig".format(tf_name, tf_name, ct) + df_construction = [] + mdta_labels = [] + bw = pyBigWig.open(ct_labels_bw_path) + if posamb_only: # Retrieve only the windows containing positively/ambiguously labeled bins + df = mdf_posamb[mdf_posamb['celltype'] == ct] + df['window_start'] = stride*(df['start'] // stride) + uniq_windows = np.unique(["{}:{}".format(x[0], x[1]) for x in zip(df['chr'], df['window_start'])]) + for u in uniq_windows: + u_chr = u.split(':')[0] + u_start = int(u.split(':')[1]) + u_end = u_start + stride + x = np.nan_to_num(bw.values(u_chr, u_start, u_end, numpy=True)) + df_construction.append((u_chr, u_start, u_end)) + mdta_labels.append(x[np.arange(0, len(x), 50)]) + else: # Retrieve all windows genome-wide + for chrID in bw.chroms(): + chromsize = bw.chroms()[chrID] + # Iterate over windows + for startc in np.arange(int(stride/2), chromsize-(2*stride), stride): + u_end = startc + stride + if u_end > chromsize: + break + x = np.nan_to_num(bw.values(chrID, startc, u_end, numpy=True)) + df_construction.append((chrID, startc, u_end)) + mdta_labels.append(x[np.arange(0, len(x), 50)]) + print(ct, chrID, time.time() - itime) + celltype_mdta_df = pd.DataFrame(df_construction, columns=['chr', 'start', 'stop']) + celltype_mdta_df.insert(len(celltype_mdta_df.columns), 'celltype', ct) + celltype_mdta.append(celltype_mdta_df) + celltype_labels.append(np.stack(mdta_labels)) + print(ct, time.time() - itime) + bw.close() + print(time.time() - itime) + + all_metadata_df = pd.concat(celltype_mdta) + all_metadata_df.to_csv( + _data_dir + 'labels/{}/{}'.format(tf_name, bed_df_filename), + sep='\t', header=False, index=False + ) + np.save(_data_dir + 'labels/{}/{}'.format(tf_name, y_arr_filename), np.vstack(celltype_labels)) + + +if __name__ == '__main__': + tf_name = 'JUND' + tfs_to_celltypes = { + 'MAX': ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562', 'A549', 'GM12878', 'liver'], + 'REST': ['H1-hESC', 'HeLa-S3', 'HepG2', 'MCF-7', 'Panc1', 'liver'], + 'JUND': ['HCT116', 'HeLa-S3', 'HepG2', 'K562', 'MCF-7', 'liver'] + } + all_celltypes = tfs_to_celltypes[tf_name] + write_label_bigwigs([x for x in all_celltypes if x != 'liver'], tf_name=tf_name) + write_label_bigwigs(['liver'], train_suffix='train_wc.labels.tsv.gz', val_suffix='test.labels.tsv.gz', tf_name=tf_name) + write_metadata_products(all_celltypes, tf_name=tf_name) diff --git a/dataset_preprocessing/encode/prep_sequence.py b/dataset_preprocessing/encode/prep_sequence.py new file mode 100644 index 00000000..b80be0da --- /dev/null +++ b/dataset_preprocessing/encode/prep_sequence.py @@ -0,0 +1,131 @@ +import argparse, time +import numpy as np + +from tqdm import tqdm + +# Sequence preprocessing. Code adapted from Jacob Schreiber. + +# Human chromosome names +chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', + 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', + 'chr20', 'chr21', 'chr22', 'chrX'] + +def one_hot_encode(sequence, ignore='N', alphabet=None, dtype='int8', verbose=False, **kwargs): + """ + Converts a string or list of characters into a one-hot encoding. + This function will take in either a string or a list and convert it into a one-hot encoding. If the input is a string, each character is assumed to be a different symbol, e.g. 'ACGT' is assumed to be a sequence of four characters. If the input is a list, the elements can be any size. + Although this function will be used here primarily to convert nucleotide sequences into one-hot encoding with an alphabet of size 4, in principle this function can be used for any types of sequences. + + Parameters + ---------- + sequence : str or list + The sequence to convert to a one-hot encoding. + ignore : str, optional + A character to indicate setting nothing to 1 for that row, keeping the encoding entirely 0's for that row. In the context of genomics, this is the N character. Default is 'N'. + alphabet : set or tuple or list, optional + A pre-defined alphabet. If None is passed in, the alphabet will be determined from the sequence, but this may be time consuming for large sequences. Default is None. + dtype : str or numpy.dtype, optional + The data type of the returned encoding. Default is int8. + verbose : bool or str, optional + Whether to display a progress bar. If a string is passed in, use as the name of the progressbar. Default is False. + kwargs : arguments + Arguments to be passed into tqdm. Default is None. + + Returns + ------- + ohe : numpy.ndarray + A binary matrix of shape (alphabet_size, sequence_length) where alphabet_size is the number of unique elements in the sequence and sequence_length is the length of the input sequence. + """ + + name = None if verbose in (True, False) else verbose + d = verbose is False + + if isinstance(sequence, str): + sequence = list(sequence) + + alphabet = alphabet or np.unique(sequence) + alphabet = [char for char in alphabet if char != ignore] + alphabet_lookup = {char: i for i, char in enumerate(alphabet)} + + ohe = np.zeros((len(sequence), len(alphabet)), dtype=dtype) + for i, char in tqdm(enumerate(sequence), disable=d, desc=name, **kwargs): + if char != ignore: + idx = alphabet_lookup[char] + ohe[i, idx] = 1 + + return ohe + + +def read_fasta(filename, include_chroms=None, exclude_chroms=None, ignore='N', alphabet=['A', 'C', 'G', 'T', 'N'], verbose=True): + """ + Read in a FASTA file and output a dictionary of sequences. + This function will take in the path to a FASTA-formatted file and output a string containing the sequence for each chromosome. Optionally, the user can specify a set of chromosomes to include or exclude from the returned dictionary. + + Parameters + ---------- + filename : str + The path to the FASTA-formatted file to open. + include_chroms : set or tuple or list, optional + The exact names of chromosomes in the FASTA file to include, excluding all others. If None, include all chromosomes (except those specified by exclude_chroms). Default is None. + exclude_chroms : set or tuple or list, optional + The exact names of chromosomes in the FASTA file to exclude, including all others. If None, include all chromosomes (or the set specified by include_chroms). Default is None. + ignore : str, optional + A character to indicate setting nothing to 1 for that row, keeping the encoding entirely 0's for that row. In the context of genomics, this is the N character. Default is 'N'. + alphabet : set or tuple or list, optional + A pre-defined alphabet. If None is passed in, the alphabet will be determined from the sequence, but this may be time consuming for large sequences. Must include the ignore character. Default is ['A', 'C', 'G', 'T', 'N']. + verbose : bool or str, optional + Whether to display a progress bar. If a string is passed in, use as the name of the progressbar. Default is False. + + Returns + ------- + chroms : dict + A dictionary of strings where the keys are the names of the chromosomes (exact strings from the header lines in the FASTA file) and the values are the strings encoded there. + """ + + sequences = {} + name, sequence = None, None + skip_chrom = False + + with open(filename, "r") as infile: + for line in tqdm(infile, disable=not verbose): + if line.startswith(">"): + if name is not None and skip_chrom is False: + sequences[name] = ''.join(sequence) + sequence = [] + name = line[1:].strip("\n") + if include_chroms is not None and name not in include_chroms: + skip_chrom = True + elif exclude_chroms is not None and name in exclude_chroms: + skip_chrom = True + else: + skip_chrom = False + else: + if skip_chrom == False: + sequence.append(line.rstrip("\n").upper()) + return sequences + + +def generate_sequence_archive(seq_path='sequence/hg19.genome.fa', output_dir): + fasta_contents = read_fasta() + kw_dict = {} + itime = time.time() + for chrom in chr_IDs: + seqstr = fasta_contents[chrom] + kw_dict[chrom] = one_hot_encode(seqstr, alphabet=['A', 'C', 'G', 'T', 'N']) + print(chrom, time.time() - itime) + + # Save as npz archive; can take several (>20) minutes + print("Saving npz archive...") + np.savez_compressed('{}/sequence'.format(output_root), **kw_dict) + print(time.time() - itime) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--seq_path', required=True) + parser.add_argument('--output_dir', required=True) + args = parser.parse_args() + + generate_sequence_archive( + seq_path=args.seq_path, + output_dir=args.output_dir) diff --git a/wilds/datasets/encode_dataset.py b/wilds/datasets/encode_dataset.py new file mode 100644 index 00000000..b2301152 --- /dev/null +++ b/wilds/datasets/encode_dataset.py @@ -0,0 +1,425 @@ +import os, time +import torch +import pandas as pd +import numpy as np +import pyBigWig +from wilds.datasets.wilds_dataset import WILDSDataset +from wilds.common.utils import subsample_idxs +from wilds.common.grouper import CombinatorialGrouper +from wilds.common.metrics.all_metrics import MultiTaskAveragePrecision + +# Human chromosomes in hg19 +chrom_sizes = {'chr1': 249250621, 'chr10': 135534747, 'chr11': 135006516, 'chr12': 133851895, 'chr13': 115169878, 'chr14': 107349540, 'chr15': 102531392, 'chr16': 90354753, 'chr17': 81195210, 'chr18': 78077248, 'chr19': 59128983, 'chr2': 243199373, 'chr20': 63025520, 'chr21': 48129895, 'chr22': 51304566, 'chr3': 198022430, 'chr4': 191154276, 'chr5': 180915260, 'chr6': 171115067, 'chr7': 159138663, 'chr8': 146364022, 'chr9': 141213431, 'chrX': 155270560} + +# quantile normalization via numpy inter/extra-polation +def anchor(input_data, sample, ref): # input 1d array + sample.sort() + ref.sort() + # 0. create the mapping function + index = np.array(np.where(np.diff(sample) != 0)) + 1 + index = index.flatten() + x = np.concatenate((np.zeros(1), sample[index])) # domain + y = np.zeros(len(x)) # codomain + for i in np.arange(0,len(index)-1, 1): + start = index[i] + end = index[i+1] + y[i+1] = np.mean(ref[start:end]) + i += 1 + start = index[i] + end = len(ref) + y[i+1] = np.mean(ref[start:end]) + # 1. interpolate + output = np.interp(input_data, x, y) + # 2. extrapolate + degree = 1 # degree of the fitting polynomial + num = 10 # number of positions for extrapolate + f1 = np.poly1d(np.polyfit(sample[-num:],ref[-num:],degree)) + output[input_data > sample[-1]] = f1(input_data[input_data > sample[-1]]) + return output + + +def wrap_anchor( + signal, + sample, + ref +): + ## 1.format as bigwig first + x = signal + z = np.concatenate(([0],x,[0])) # pad two zeroes + # find boundary + starts = np.where(np.diff(z) != 0)[0] + ends = starts[1:] + starts = starts[:-1] + vals = x[starts] + if starts[0] != 0: + ends = np.concatenate(([starts[0]],ends)) + starts = np.concatenate(([0],starts)) + vals = np.concatenate(([0],vals)) + if ends[-1] != len(signal): + starts = np.concatenate((starts,[ends[-1]])) + ends = np.concatenate((ends,[len(signal)])) + vals = np.concatenate((vals,[0])) + + ## 2.then quantile normalization + vals_anchored = anchor(vals, sample, ref) + return vals_anchored, starts, ends + + +def dnase_normalize( + input_bw_celltype, + ref_celltypes, + out_fname, + data_pfx +): + if not data_pfx.endswith('/'): + data_pfx = data_pfx + '/' + itime = time.time() + sample = np.load(data_pfx + "qn.{}.npy".format(input_bw_celltype)) + ref = np.zeros(len(sample)) + for ct in ref_celltypes: + ref += (1.0/len(ref_celltypes))*np.load(data_pfx + "qn.{}.npy".format(ct)) + + chromsizes_list = [(k, v) for k, v in chrom_sizes.items()] + bw_output = pyBigWig.open(out_fname, 'w') + bw_output.addHeader(chromsizes_list) + + for the_chr in chrom_sizes: + signal = np.zeros(chrom_sizes[the_chr]) + bw = pyBigWig.open(data_pfx + 'DNASE.{}.fc.signal.bigwig'.format(input_bw_celltype)) + signal += np.nan_to_num(np.array(bw.values(the_chr, 0, chrom_sizes[the_chr]))) + bw.close() + vals_anchored, starts, ends = wrap_anchor(signal, sample, ref) + # write normalized dnase file. + chroms = np.array([the_chr] * len(vals_anchored)) + bw_output.addEntries(chroms, starts, ends=ends, values=vals_anchored) + print(input_bw_celltype, the_chr, time.time() - itime) + + bw_output.close() + + +class EncodeDataset(WILDSDataset): + """ + ENCODE dataset of transcription factor binding sites. + This is a subset of the dataset from the ENCODE-DREAM in vivo Transcription Factor Binding Site Prediction Challenge. + + Input (x): + 12800-base-pair regions of sequence with a quantified chromatin accessibility readout. + + Label (y): + y is a 128-bit vector, with each element y_i indicating the binding status of a 200bp window. It is 1 if this 200bp region is bound by the transcription factor, and 0 otherwise, for i = 0,1,...,127. + + Concretely, suppose the input window x starts at coordinate sc, extending until coordinate (sc+12800). Then y_i is the label of the window starting at coordinate (sc+3200)+(50*i). + + Metadata: + Each sequence is annotated with the celltype of origin (a string) and the chromosome of origin (a string). + + Website: + https://www.synapse.org/#!Synapse:syn6131484 . This is the website for the challenge; the data can be downloaded from here as per the instructions in dataset_preprocessing/encode/README.md. + """ + + _dataset_name = 'encode' + _versions_dict = { + '1.0': { + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x9c282b6e9082440f9dcd61bb605c1eab/contents/blob/', + 'compressed_size': None}} + + def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): + itime = time.time() + self._version = version + self._data_dir = self.initialize_data_dir(root_dir, download) + self._y_size = 128 + + # Construct splits + train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] + val_chroms = ['chr2', 'chr9', 'chr11'] + test_chroms = ['chr1', 'chr8', 'chr21'] + official_train_cts = { + 'MAX': ['H1-hESC', 'HCT116', 'HeLa-S3', 'K562', 'A549', 'GM12878'], + 'JUND': ['HCT116', 'HeLa-S3', 'K562', 'MCF-7'] + } + official_val_cts = { + 'MAX': ['HepG2'], 'JUND': ['HepG2'] + } + official_test_cts = { + 'MAX': ['liver'], 'JUND': ['liver'] + } + + # Set the TF in split_scheme by prefacing it with 'tf..' + self._transcription_factor = 'MAX' + if 'tf.' in split_scheme: + tkns = split_scheme.split('.') + self._transcription_factor = tkns[1] + split_scheme = '.'.join(tkns[2:]) + self._split_scheme = split_scheme + + train_celltypes = official_train_cts[self._transcription_factor] + val_celltype = official_val_cts[self._transcription_factor] + test_celltype = official_test_cts[self._transcription_factor] + + if self._split_scheme == 'official': + splits = { + 'train': { + 'chroms': train_chroms, + 'celltypes': train_celltypes + }, + 'id_val': { + 'chroms': val_chroms, + 'celltypes': train_celltypes + }, + 'val': { + 'chroms': val_chroms, + 'celltypes': val_celltype + }, + 'test': { + 'chroms': test_chroms, + 'celltypes': test_celltype + }, + 'id_test': { + 'chroms': test_chroms, + 'celltypes': train_celltypes + } + } + self._split_dict = { + 'train': 0, + 'val': 1, + 'test': 2, + 'id_val': 3, + 'id_test': 4 + } + self._split_names = { + 'train': 'Train', + 'val': 'Validation (OOD)', + 'test': 'Test', + 'id_val': 'Validation (ID)', + 'id_test': 'Test (ID)', + } + elif self._split_scheme == 'in-dist': + splits = { + 'train': { + 'chroms': train_chroms, + 'celltypes': test_celltype, + }, + 'val': { + 'chroms': val_chroms, + 'celltypes': test_celltype + }, + 'test': { + 'chroms': test_chroms, + 'celltypes': test_celltype + }, + } + self._split_dict = { + 'train': 0, + 'val': 1, + 'test': 2, + } + self._split_names = { + 'train': 'Train', + 'val': 'Validation (OOD)', + 'test': 'Test', + } + elif 'id-' in self._split_scheme: + test_celltype = [ self._split_scheme.split('id-')[1] ] + splits = { + 'train': { + 'chroms': train_chroms, + 'celltypes': test_celltype, + }, + 'val': { + 'chroms': val_chroms, + 'celltypes': test_celltype + }, + 'test': { + 'chroms': test_chroms, + 'celltypes': test_celltype + }, + } + self._split_dict = { + 'train': 0, + 'val': 1, + 'test': 2, + } + self._split_names = { + 'train': 'Train', + 'val': 'Validation (OOD)', + 'test': 'Test', + } + + # Add new split scheme specifying custom test and val celltypes in the format val..test., e.g. self._split_scheme == 'official' is equivalent to self._split_scheme == 'val.HepG2.test.liver' + elif '.' in self._split_scheme: + all_celltypes = train_celltypes + val_celltype + test_celltype + in_val_ct = self._split_scheme.split('.')[1] + in_test_ct = self._split_scheme.split('.')[3] + train_celltypes = [ct for ct in all_celltypes if ((ct != in_val_ct) and (ct != in_test_ct))] + val_celltype = [in_val_ct] + test_celltype = [in_test_ct] + splits = { + 'train': { + 'chroms': train_chroms, + 'celltypes': train_celltypes + }, + 'id_val': { + 'chroms': val_chroms, + 'celltypes': train_celltypes + }, + 'val': { + 'chroms': val_chroms, + 'celltypes': val_celltype + }, + 'test': { + 'chroms': test_chroms, + 'celltypes': test_celltype + }, + 'id_test': { + 'chroms': test_chroms, + 'celltypes': train_celltypes + } + } + self._split_dict = { + 'train': 0, + 'val': 1, + 'test': 2, + 'id_val': 3, + 'id_test': 4 + } + self._split_names = { + 'train': 'Train', + 'val': 'Validation (OOD)', + 'test': 'Test', + 'id_val': 'Validation (ID)', + 'id_test': 'Test (ID)', + } + else: + raise ValueError(f'Split scheme {self._split_scheme} not recognized') + + # Read in metadata and labels + self._metadata_df = pd.read_csv( + self._data_dir + '/labels/{}/metadata_df.bed'.format(self._transcription_factor), + sep='\t', header=None, + index_col=None, names=['chr', 'start', 'stop', 'celltype'] + ) + self._y_array = torch.tensor(np.load( + self._data_dir + '/labels/{}/metadata_y.npy'.format(self._transcription_factor))) + + # ~10% of the dataset has ambiguous labels + # i.e., we can't tell if there is a binding event or not. + # This typically happens at the flanking regions of peaks. + # For our purposes, we will ignore these ambiguous labels during training and eval. + self.y_array[self.y_array == 0.5] = float('nan') + + self._split_array = -1 * np.ones(self._metadata_df.shape[0]).astype(int) + for split, d in splits.items(): + chrom_mask = np.isin(self._metadata_df['chr'], d['chroms']) + celltype_mask = np.isin(self._metadata_df['celltype'], d['celltypes']) + self._split_array[chrom_mask & celltype_mask] = self._split_dict[split] + + keep_mask = (self._split_array != -1) + + # Remove all-zero sequences from training. + train_mask = (self._split_array == self._split_dict['train']) + allzeroes_mask = (self._y_array.sum(axis=1) == 0).numpy() + keep_mask = keep_mask & ~(train_mask & allzeroes_mask) + + # Subsample the testing and validation indices, to speed up evaluation. + # For the OOD splits (val and test), we subsample by a factor of 3 + # For the id_val and id_test splits, we subsample by a factor of 3*(# of training celltypes) + for subsample_seed, (split, subsample_factor) in enumerate([ + ('val', 3), + ('test', 3), + ('id_val', 3*len(splits['train']['celltypes'])), + ('id_test', 3*len(splits['train']['celltypes']))]): + if split not in self._split_dict: continue + split_mask = (self._split_array == self._split_dict[split]) + split_idxs = np.arange(len(self._split_array))[split_mask] + idxs_to_remove = subsample_idxs( + split_idxs, + num=len(split_idxs) // subsample_factor, + seed=subsample_seed, + take_rest=True) + keep_mask[idxs_to_remove] = False + + self._metadata_df = self._metadata_df[keep_mask] + self._split_array = self._split_array[keep_mask] + self._y_array = self._y_array[keep_mask] + + self._all_chroms = sorted(list({chrom for _, d in splits.items() for chrom in d['chroms']})) + self._all_celltypes = sorted(list({chrom for _, d in splits.items() for chrom in d['celltypes']})) + + # Load sequence into memory + sequence_filename = os.path.join(self._data_dir, 'sequence.npz') + seq_arr = np.load(sequence_filename) + self._seq_bp = {} + for chrom in self._all_chroms: + self._seq_bp[chrom] = seq_arr[chrom] + print(chrom, time.time() - itime) + del seq_arr + + # Set up file handles for DNase features, writing normalized DNase tracks along the way if they aren't already written. + self._dnase_allcelltypes = {} + for ct in self._all_celltypes: + orig_dnase_bw_path = os.path.join(self._data_dir, 'DNASE.{}.fc.signal.bigwig'.format(ct)) + dnase_bw_path = os.path.join(self._data_dir, 'DNase.{}.{}.{}.bigwig'.format(self._transcription_factor, ct, self._split_scheme)) + if not os.path.exists(dnase_bw_path): + ref_celltypes = splits['train']['celltypes'] + dnase_normalize(ct, ref_celltypes, out_fname=dnase_bw_path, data_pfx=self._data_dir) + self._dnase_allcelltypes[ct] = pyBigWig.open(dnase_bw_path) + + # Load subsampled DNase arrays for normalization purposes + self._dnase_qnorm_arrays = {} + for ct in self._all_celltypes: + qnorm_arr_path = os.path.join(self._data_dir, 'qn.{}.npy'.format(ct)) + self._dnase_qnorm_arrays[ct] = np.load(qnorm_arr_path) + self._norm_ref_distr = np.zeros(len(self._dnase_qnorm_arrays[ct])) + test_cts = splits['test']['celltypes'] + num_to_avg = len(self._all_celltypes) - len(test_cts) + for ct in self._all_celltypes: + if ct not in test_cts: + self._norm_ref_distr += (1.0/num_to_avg)*self._dnase_qnorm_arrays[ct] + + # Set up metadata fields, map, array + self._metadata_fields = ['chr', 'celltype'] + self._metadata_map = {} + self._metadata_map['chr'] = self._all_chroms + self._metadata_map['celltype'] = self._all_celltypes + chr_ints = self._metadata_df['chr'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['chr'])] )).values + celltype_ints = self._metadata_df['celltype'].replace(dict( [(y, x) for x, y in enumerate(self._metadata_map['celltype'])] )).values + self._metadata_array = torch.stack( + (torch.LongTensor(chr_ints), + torch.LongTensor(celltype_ints) + ), + dim=1) + + self._eval_grouper = CombinatorialGrouper( + dataset=self, + groupby_fields=['celltype']) + + self._metric = MultiTaskAveragePrecision() + + super().__init__(root_dir, download, split_scheme) + + def get_input(self, idx, window_size=12800): + """ + Returns x for a given idx in metadata_array, which has been filtered to only take windows with the desired stride. + Computes this from: + (1) sequence features in self._seq_bp + (2) DNase bigwig file handles in self._dnase_allcelltypes + (3) Metadata for the index (location along the genome with 6400bp window width) + (4) Window_size, the length of sequence returned (centered on the 6400bp region in (3)) + """ + this_metadata = self._metadata_df.iloc[idx, :] + chrom = this_metadata['chr'] + interval_start = this_metadata['start'] - int(window_size/4) + interval_end = interval_start + window_size + seq_this = self._seq_bp[this_metadata['chr']][interval_start:interval_end] + dnase_bw = self._dnase_allcelltypes[this_metadata['celltype']] + dnase_this = np.nan_to_num(dnase_bw.values(chrom, interval_start, interval_end, numpy=True)) + return torch.tensor(np.column_stack( + [seq_this, + dnase_this] + ).T) + + def eval(self, y_pred, y_true, metadata): + return self.standard_group_eval( + self._metric, + self._eval_grouper, + y_pred, y_true, metadata) From cd0196610fe53d0ec5a4d023c4ef7e4923345f51 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Tue, 6 Jul 2021 23:05:04 -0700 Subject: [PATCH 214/244] set_grad_enabled change --- examples/algorithms/algorithm.py | 4 +--- examples/train.py | 5 ++++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/algorithms/algorithm.py b/examples/algorithms/algorithm.py index 5c734766..2a738e0a 100644 --- a/examples/algorithms/algorithm.py +++ b/examples/algorithms/algorithm.py @@ -40,14 +40,12 @@ def evaluate(self, batch): """ raise NotImplementedError - # Taken from domainbed def train(self, mode=True): """ Switch to train mode """ self.is_training = mode - super().train(mode) - torch.set_grad_enabled(mode) + super().train(mode) self.reset_log() @property diff --git a/examples/train.py b/examples/train.py index 5b0b591e..c1caa3fd 100644 --- a/examples/train.py +++ b/examples/train.py @@ -11,8 +11,10 @@ def run_epoch(algorithm, dataset, general_logger, epoch, config, train): if train: algorithm.train() + torch.set_grad_enabled(True) else: algorithm.eval() + torch.set_grad_enabled(False) # Not preallocating memory is slower # but makes it easier to handle different types of data loaders @@ -30,7 +32,7 @@ def run_epoch(algorithm, dataset, general_logger, epoch, config, train): if train: batch_results = algorithm.update(batch) else: - batch_results = algorithm.evaluate(batch) + batch_results = algorithm.evaluate(batch) # These tensors are already detached, but we need to clone them again # Otherwise they don't get garbage collected properly in some versions @@ -115,6 +117,7 @@ def train(algorithm, datasets, general_logger, config, epoch_offset, best_val_me def evaluate(algorithm, datasets, epoch, general_logger, config, is_best): algorithm.eval() + torch.set_grad_enabled(False) for split, dataset in datasets.items(): if (not config.evaluate_all_splits) and (split not in config.eval_splits): continue From 353f5e1a9adccd40a92b00db18b25bc867c0828b Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Wed, 7 Jul 2021 21:37:21 -0700 Subject: [PATCH 215/244] support rxrx1 and globalwheat in the evaluation script --- examples/evaluate.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/examples/evaluate.py b/examples/evaluate.py index a5e7a7e3..e7c4cacc 100644 --- a/examples/evaluate.py +++ b/examples/evaluate.py @@ -106,6 +106,10 @@ def get_metrics(dataset_name: str) -> List[str]: return ["r_wg", "r_all"] elif "py150" == dataset_name: return ["acc", "Acc (Overall)"] + elif "globalwheat" == dataset_name: + return ["detection_acc_avg_dom"] + elif "rxrx1" == dataset_name: + return ["acc_avg", "acc_wg"] else: raise ValueError(f"Invalid dataset: {dataset_name}") From baf48d30a470dddcda85854bd7f2848d00d20338 Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Sat, 10 Jul 2021 11:53:33 -0700 Subject: [PATCH 216/244] add globalwheat to benchmark_datasets --- wilds/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/wilds/__init__.py b/wilds/__init__.py index a1f28996..d4a114be 100644 --- a/wilds/__init__.py +++ b/wilds/__init__.py @@ -11,6 +11,7 @@ 'fmow', 'py150', 'rxrx1', + 'globalwheat', ] additional_datasets = [ From 6c4552d875c00596b497da660044159a1f5bc3c0 Mon Sep 17 00:00:00 2001 From: aikanor Date: Sat, 10 Jul 2021 13:00:43 -0700 Subject: [PATCH 217/244] batch size --- examples/configs/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 47e502a8..57bd3298 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -121,7 +121,7 @@ 'scheduler_kwargs': {'milestones':[3,6], 'gamma': 0.1}, # used to be 6, 9, with 12 epochs # 'scheduler': 'linear_schedule_with_warmup', # 'scheduler_kwargs': {'num_warmup_steps': 800}, # about 160 minibatches per epoch - 'batch_size': 256, + 'batch_size': 128, 'lr': 1e-3, 'weight_decay': 1e-4, 'n_epochs': 12, From ad312893949fc9d0f2bab66f60302be92a7194d0 Mon Sep 17 00:00:00 2001 From: aikanor Date: Sat, 10 Jul 2021 23:46:43 -0700 Subject: [PATCH 218/244] code cleanup --- .../encode/prep_accessibility.py | 1 - .../encode/prep_datasets.ipynb | 279 ------------------ .../encode/prep_metadata_labels.py | 15 +- examples/configs/datasets.py | 4 +- examples/configs/supported.py | 2 +- 5 files changed, 9 insertions(+), 292 deletions(-) delete mode 100644 dataset_preprocessing/encode/prep_datasets.ipynb diff --git a/dataset_preprocessing/encode/prep_accessibility.py b/dataset_preprocessing/encode/prep_accessibility.py index a3303317..6afcab2c 100644 --- a/dataset_preprocessing/encode/prep_accessibility.py +++ b/dataset_preprocessing/encode/prep_accessibility.py @@ -120,7 +120,6 @@ def dnase_normalize( bw_output = pyBigWig.open(data_pfx + 'DNase.{}.{}.bigwig'.format( input_bw_celltype, out_fname), 'w') bw_output.addHeader(chromsizes_list) - # bw_output.addHeader(list(zip(chr_all , num_bp)), maxZooms=0) # zip two turples for the_chr in chrom_sizes: signal = np.zeros(chrom_sizes[the_chr]) diff --git a/dataset_preprocessing/encode/prep_datasets.ipynb b/dataset_preprocessing/encode/prep_datasets.ipynb deleted file mode 100644 index 78235fd7..00000000 --- a/dataset_preprocessing/encode/prep_datasets.ipynb +++ /dev/null @@ -1,279 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import prep_utils, scipy, numpy as np, time\n", - "from scipy import sparse\n", - "\n", - "# Human chromosome names\n", - "chr_IDs = ['chr1', 'chr2', 'chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr8', 'chr9', 'chr10', 'chr11', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr21', 'chr22', 'chrX']" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Sequence" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "62743362it [00:54, 1151676.47it/s]\n" - ] - } - ], - "source": [ - "a = prep_utils.read_fasta('sequence/hg19.genome.fa')\n", - "\n", - "kw_dict = {}\n", - "itime = time.time()\n", - "for chrom in chr_IDs:\n", - " seqstr = a[chrom]\n", - " kw_dict[chrom] = prep_utils.one_hot_encode(seqstr, alphabet=['A', 'C', 'G', 'T', 'N'])\n", - " print(chrom, time.time() - itime)\n", - "\n", - "# Save as npz archive; can take several (>20) minutes\n", - "print(\"Saving npz archive...\")\n", - "np.savez_compressed('codalab_archive/sequence', **kw_dict)\n", - "print(time.time() - itime)\n", - "\n", - "# # Save as npy arrays\n", - "# itime = time.time()\n", - "# for chrom in kw_dict:\n", - "# np.save('sequence/{}.npy'.format(chrom), kw_dict[chrom])\n", - "# print(chrom, time.time() - itime)\n", - "\n", - "npz_archive = np.load('codalab_archive/sequence.npz')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## DNase" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "liver 0.006468534469604492\n", - "chr1 8.260387659072876\n", - "chr1 13.276052474975586\n", - "chr10 17.844778299331665\n", - "chr10 25.784512758255005\n", - "chr11 30.30143165588379\n", - "chr11 33.256701707839966\n", - "chr12 37.791435956954956\n", - "chr12 40.85292291641235\n", - "chr13 44.619521141052246\n", - "chr13 47.792500495910645\n", - "chr14 51.4214243888855\n", - "chr14 53.6813702583313\n", - "chr15 56.946401834487915\n", - "chr15 59.10466551780701\n", - "chr16 61.939475774765015\n", - "chr16 63.999470472335815\n", - "chr17 66.63648653030396\n", - "chr17 68.4126443862915\n", - "chr18 71.05454993247986\n", - "chr18 72.90085673332214\n", - "chr19 74.78594756126404\n", - "chr19 76.80954170227051\n", - "chr2 85.25815343856812\n", - "chr2 95.36479425430298\n", - "chr20 97.74516272544861\n", - "chr20 99.27151441574097\n", - "chr21 100.82207584381104\n", - "chr21 103.02815318107605\n", - "chr22 104.63926863670349\n", - "chr22 106.02127361297607\n", - "chr3 112.71910071372986\n", - "chr3 117.30491018295288\n", - "chr4 123.77405095100403\n", - "chr4 128.67069339752197\n", - "chr5 134.89299392700195\n", - "chr5 138.83413815498352\n", - "chr6 144.83386087417603\n", - "chr6 149.115407705307\n", - "chr7 154.4929392337799\n", - "chr7 157.8094253540039\n", - "chr8 162.8749077320099\n", - "chr8 165.9331293106079\n", - "chr9 170.5435709953308\n", - "chr9 173.46287417411804\n", - "chrX 178.5410988330841\n", - "chrX 185.49569463729858\n", - "chrY 187.14469981193542\n", - "chrY 189.6306025981903\n", - "MCF-7 0.01819300651550293\n", - "chr1 8.266149282455444\n", - "chr1 13.86928129196167\n", - "chr10 18.216674327850342\n", - "chr10 20.975315809249878\n", - "chr11 25.302175998687744\n", - "chr11 34.40013885498047\n", - "chr12 38.70525503158569\n", - "chr12 41.59175777435303\n", - "chr13 45.130286693573\n", - "chr13 47.67305374145508\n", - "chr14 51.26033353805542\n", - "chr14 53.59153509140015\n", - "chr15 56.858047008514404\n", - "chr15 59.08759665489197\n", - "chr16 62.03992414474487\n", - "chr16 63.99170207977295\n", - "chr17 67.05595779418945\n", - "chr17 69.3644654750824\n", - "chr18 71.78018283843994\n", - "chr18 73.58044695854187\n", - "chr19 75.70175457000732\n", - "chr19 79.72573828697205\n", - "chr2 87.675612449646\n", - "chr2 92.91672372817993\n", - "chr20 95.51653027534485\n", - "chr20 96.88600373268127\n", - "chr21 98.43806076049805\n", - "chr21 103.25369572639465\n", - "chr22 104.84882092475891\n", - "chr22 106.21143817901611\n", - "chr3 112.67947244644165\n", - "chr3 116.70610451698303\n", - "chr4 122.56520342826843\n", - "chr4 126.52856135368347\n", - "chr5 132.38469552993774\n", - "chr5 136.28370690345764\n", - "chr6 141.5743978023529\n", - "chr6 145.10061717033386\n", - "chr7 150.44007444381714\n", - "chr7 155.55760312080383\n", - "chr8 160.3683557510376\n", - "chr8 163.43416213989258\n", - "chr9 167.90313267707825\n", - "chr9 172.0667405128479\n", - "chrX 176.69336795806885\n", - "chrX 181.83150935173035\n", - "K562 0.007167339324951172\n", - "chr1 8.471662998199463\n", - "chr1 13.464861631393433\n", - "chr10 17.858335494995117\n", - "chr10 20.700791835784912\n", - "chr11 25.168848276138306\n", - "chr11 28.01260733604431\n", - "chr12 32.38129758834839\n", - "chr12 35.250038385391235\n", - "chr13 38.72063398361206\n", - "chr13 43.30442762374878\n", - "chr14 46.55065989494324\n", - "chr14 51.87103271484375\n", - "chr15 55.08980083465576\n", - "chr15 57.35198903083801\n", - "chr16 60.444990396499634\n", - "chr16 62.56146717071533\n", - "chr17 65.33607196807861\n", - "chr17 75.77480912208557\n", - "chr18 78.25007915496826\n", - "chr18 82.4424319267273\n", - "chr19 84.73718905448914\n", - "chr19 86.0900673866272\n", - "chr2 93.6916708946228\n", - "chr2 98.61803960800171\n", - "chr20 100.70567536354065\n", - "chr20 102.18551921844482\n", - "chr21 103.75095820426941\n", - "chr21 104.96330642700195\n", - "chr22 106.666348695755\n", - "chr22 108.20869731903076\n", - "chr3 114.6058874130249\n", - "chr3 123.16646194458008\n", - "chr4 129.07538533210754\n", - "chr4 135.95439338684082\n", - "chr5 141.63543701171875\n", - "chr5 148.8255476951599\n", - "chr6 154.68585968017578\n", - "chr6 160.3087387084961\n", - "chr7 165.7410364151001\n", - "chr7 169.09255123138428\n", - "chr8 173.68864274024963\n", - "chr8 176.73100185394287\n", - "chr9 181.10383462905884\n", - "chr9 184.0267071723938\n", - "chrX 188.59823846817017\n", - "chrX 191.7538366317749\n" - ] - } - ], - "source": [ - "### import pyBigWig\n", - "import glob\n", - "\n", - "dnases = {}\n", - "celltypes = ['A549', 'GM12878', 'H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562']\n", - "\n", - "for ctype in celltypes:#glob.glob('dnase_bigwigs/*'):\n", - " itime = time.time()\n", - " # ctype = pth.split('/')[1].split('.')[1]\n", - " if ctype not in ['liver', 'MCF-7', 'K562']:\n", - " continue\n", - " bw = pyBigWig.open(\"dnase_bigwigs/DNASE.{}.fc.signal.bigwig\".format(ctype))\n", - " chromsizes = bw.chroms()\n", - " print(ctype, time.time() - itime)\n", - " dn_dict = {}\n", - " for chrom in chromsizes: #chr_IDs:\n", - " x = bw.values(chrom, 0, chromsizes[chrom], numpy=True)\n", - " dn_dict[chrom] = np.nan_to_num(x).astype(np.float16) # half-precision makes things significantly smaller (less time to load)\n", - " print(chrom, time.time() - itime)\n", - " \n", - " np.save('dnase/{}/{}.npy'.format(ctype, chrom), dn_dict[chrom])\n", - " print(chrom, time.time() - itime)\n", - " dnases[ctype] = dn_dict\n", - "\n", - "for ctype in dnases:\n", - " itime = time.time()\n", - " print(ctype)\n", - " dn_dict = dnases[ctype]\n", - " \n", - " # Save as npz archive\n", - " np.savez_compressed('codalab_archive/{}_dnase'.format(ctype), **dn_dict)\n", - " print(time.time() - itime)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.5" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/dataset_preprocessing/encode/prep_metadata_labels.py b/dataset_preprocessing/encode/prep_metadata_labels.py index 8d1fc537..d6bd7b66 100644 --- a/dataset_preprocessing/encode/prep_metadata_labels.py +++ b/dataset_preprocessing/encode/prep_metadata_labels.py @@ -59,7 +59,7 @@ def write_label_bigwigs( sep='\t', header=None, index_col=None, names=['chr', 'start', 'stop', 'y', 'celltype'] ) - # Write the binned labels to bigwig files - genome-wide labels + # Write the binned labels to bigwig files, genome-wide labels chromsizes_list = [(k, v) for k, v in chrom_sizes.items()] for ct in celltypes: ct_labels_bw_path = _data_dir + "labels/{}/{}_{}.bigwig".format( @@ -134,13 +134,12 @@ def write_metadata_products( if __name__ == '__main__': - tf_name = 'JUND' tfs_to_celltypes = { - 'MAX': ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562', 'A549', 'GM12878', 'liver'], - 'REST': ['H1-hESC', 'HeLa-S3', 'HepG2', 'MCF-7', 'Panc1', 'liver'], + 'MAX': ['H1-hESC', 'HCT116', 'HeLa-S3', 'HepG2', 'K562', 'A549', 'GM12878', 'liver'], 'JUND': ['HCT116', 'HeLa-S3', 'HepG2', 'K562', 'MCF-7', 'liver'] } - all_celltypes = tfs_to_celltypes[tf_name] - write_label_bigwigs([x for x in all_celltypes if x != 'liver'], tf_name=tf_name) - write_label_bigwigs(['liver'], train_suffix='train_wc.labels.tsv.gz', val_suffix='test.labels.tsv.gz', tf_name=tf_name) - write_metadata_products(all_celltypes, tf_name=tf_name) + for tf_name in tfs_to_celltypes: + all_celltypes = tfs_to_celltypes[tf_name] + write_label_bigwigs([x for x in all_celltypes if x != 'liver'], tf_name=tf_name) + write_label_bigwigs(['liver'], train_suffix='train_wc.labels.tsv.gz', val_suffix='test.labels.tsv.gz', tf_name=tf_name) + write_metadata_products(all_celltypes, tf_name=tf_name) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 57bd3298..68cbe8a3 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -118,9 +118,7 @@ 'val_metric_decreasing': False, 'optimizer': 'Adam', 'scheduler': 'MultiStepLR', - 'scheduler_kwargs': {'milestones':[3,6], 'gamma': 0.1}, # used to be 6, 9, with 12 epochs - # 'scheduler': 'linear_schedule_with_warmup', - # 'scheduler_kwargs': {'num_warmup_steps': 800}, # about 160 minibatches per epoch + 'scheduler_kwargs': {'milestones':[3,6], 'gamma': 0.1}, 'batch_size': 128, 'lr': 1e-3, 'weight_decay': 1e-4, diff --git a/examples/configs/supported.py b/examples/configs/supported.py index 1a6b0662..a1ba51be 100644 --- a/examples/configs/supported.py +++ b/examples/configs/supported.py @@ -4,7 +4,7 @@ # metrics from wilds.common.metrics.loss import ElementwiseLoss, Loss, MultiTaskLoss -from wilds.common.metrics.all_metrics import Accuracy, MultiTaskAccuracy, MSE, multiclass_logits_to_pred, binary_logits_to_pred, MultiTaskAveragePrecision, MultiTaskPREven +from wilds.common.metrics.all_metrics import Accuracy, MultiTaskAccuracy, MSE, multiclass_logits_to_pred, binary_logits_to_pred, MultiTaskAveragePrecision losses = { 'cross_entropy': ElementwiseLoss(loss_fn=nn.CrossEntropyLoss(reduction='none')), From cc937f70c024ec14bf6d27a04789e71cb379ec6e Mon Sep 17 00:00:00 2001 From: aikanor Date: Sun, 11 Jul 2021 02:20:26 -0700 Subject: [PATCH 219/244] code cleanup --- dataset_preprocessing/encode/README.md | 7 ++-- .../encode/prep_accessibility.py | 32 +------------------ 2 files changed, 3 insertions(+), 36 deletions(-) diff --git a/dataset_preprocessing/encode/README.md b/dataset_preprocessing/encode/README.md index 1d045e07..598177bd 100644 --- a/dataset_preprocessing/encode/README.md +++ b/dataset_preprocessing/encode/README.md @@ -11,7 +11,7 @@ 3. Download the DNase accessibility data. This consists of whole-genome DNase files in bigwig format from https://guanfiles.dcmb.med.umich.edu/Leopard/dnase_bigwig/. These are saved with filename `DNASE..fc.signal.bigwig`. -4. Run `prep_accessibility.py`. +4. Run `python prep_accessibility.py`. 5. Download the labels from the challenge into a label directory `labels/` created for this purpose: - The training chromosome labels for the challenge's training cell types from https://www.synapse.org/#!Synapse:syn7413983 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn7415202 for the TF MAX, , downloaded as MAX.train.labels.tsv.gz ). @@ -19,8 +19,5 @@ - The validation chromosome labels for the challenge's training cell types from https://www.synapse.org/#!Synapse:syn8441154 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn8442103 for the TF MAX, downloaded as MAX.val.labels.tsv.gz ). - The validation chromosome labels for the challenge's evaluation cell type (liver) from https://www.synapse.org/#!Synapse:syn8442975 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn8443021 for the TF MAX, downloaded as MAX.test.labels.tsv.gz ). -6. Run `prep_metadata_labels.py`. +6. Run `python prep_metadata_labels.py`. - -#### Instructions to run on Codalab bundle -7. diff --git a/dataset_preprocessing/encode/prep_accessibility.py b/dataset_preprocessing/encode/prep_accessibility.py index 6afcab2c..62228dba 100644 --- a/dataset_preprocessing/encode/prep_accessibility.py +++ b/dataset_preprocessing/encode/prep_accessibility.py @@ -71,17 +71,11 @@ def anchor(input_data, sample, ref): # input 1d array degree = 1 # degree of the fitting polynomial num = 10 # number of positions for extrapolate f1 = np.poly1d(np.polyfit(sample[-num:],ref[-num:],degree)) -# f2=np.poly1d(np.polyfit(sample[:num],ref[:num],degree)) output[input_data > sample[-1]] = f1(input_data[input_data > sample[-1]]) -# output[input_data Date: Sun, 11 Jul 2021 12:12:41 -0700 Subject: [PATCH 220/244] evaluate globalwheat --- examples/evaluate.py | 40 ++++++++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/examples/evaluate.py b/examples/evaluate.py index e7c4cacc..199271ac 100644 --- a/examples/evaluate.py +++ b/examples/evaluate.py @@ -1,9 +1,9 @@ import argparse import json import os -import sys import urllib.request from ast import literal_eval +from operator import itemgetter from typing import Any, Dict, List from urllib.parse import urlparse @@ -12,6 +12,7 @@ from wilds import benchmark_datasets from wilds import get_dataset +from wilds.datasets.globalwheat_dataset import GlobalWheatDataset from wilds.datasets.wilds_dataset import WILDSDataset, WILDSSubset @@ -83,10 +84,12 @@ def get_prediction_file( ) -> str: run_id = f"{dataset_name}_split:{split}_{replicate}" for file in os.listdir(predictions_dir): - if file.startswith(run_id) and file.endswith(".csv"): + if file.startswith(run_id) and ( + file.endswith(".csv") or file.endswith(".pth") + ): return file raise FileNotFoundError( - f"Could not find CSV prediction file that starts with {run_id}." + f"Could not find CSV or pth prediction file that starts with {run_id}." ) def get_metrics(dataset_name: str) -> List[str]: @@ -140,11 +143,8 @@ def get_metrics(dataset_name: str) -> List[str]: ) full_path = os.path.join(predictions_dir, predictions_file) predicted_labels: List[Any] = get_predictions(full_path) - predicted_labels_tensor: torch.Tensor = torch.from_numpy( - np.array(predicted_labels) - ) metric_results: Dict[str, float] = evaluate_replicate( - wilds_dataset, split, predicted_labels_tensor + wilds_dataset, split, predicted_labels ) for metric in metrics: replicates_results[split][metric].append(metric_results[metric]) @@ -185,15 +185,21 @@ def evaluate_replicate( """ # Dataset will only be downloaded if it does not exist subset: WILDSSubset = dataset.get_subset(split) - true_labels: torch.Tensor = subset.y_array metadata: torch.Tensor = subset.metadata_array - # predicted_labels.resize_(true_labels.shape) - if predicted_labels.shape != true_labels.shape: - predicted_labels.unsqueeze_(-1) - return dataset.eval(predicted_labels, true_labels, metadata)[0] + + if type(dataset) == GlobalWheatDataset: + predicted_labels = predicted_labels + true_labels = list(itemgetter(*subset.indices)(subset.dataset.y_array)) + else: + true_labels: torch.Tensor = subset.y_array + if predicted_labels.shape != true_labels.shape: + predicted_labels.unsqueeze_(-1) + + results_dict, results_str = dataset.eval(predicted_labels, true_labels, metadata) + return results_dict -def get_predictions(path: str) -> List[Any]: +def get_predictions(path: str) -> torch.tensor: """ Extract out the predictions from the file at path. @@ -203,7 +209,9 @@ def get_predictions(path: str) -> List[Any]: Return: List of predictions. """ - if is_path_url(path): + if path.endswith(".pth"): + return torch.load(path) + elif is_path_url(path): data = urllib.request.urlopen(path) else: file = open(path, mode="r") @@ -211,7 +219,7 @@ def get_predictions(path: str) -> List[Any]: file.close() predicted_labels = [literal_eval(line.rstrip()) for line in data if line.rstrip()] - return predicted_labels + return torch.from_numpy(np.array(predicted_labels)) def is_path_url(path: str) -> bool: @@ -243,7 +251,7 @@ def main(): parser.add_argument( "predictions_dir", type=str, - help="Path to prediction CSV files.", + help="Path to prediction CSV or pth files.", ) parser.add_argument( "output_dir", From 40a5c29cecc3fbf2203c6a47765d657f01f8fda9 Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Sun, 11 Jul 2021 14:36:52 -0700 Subject: [PATCH 221/244] fix type checking --- examples/evaluate.py | 47 ++++++++++++++++++++++++++------------------ 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/examples/evaluate.py b/examples/evaluate.py index 199271ac..9a150e07 100644 --- a/examples/evaluate.py +++ b/examples/evaluate.py @@ -4,7 +4,8 @@ import urllib.request from ast import literal_eval from operator import itemgetter -from typing import Any, Dict, List +from typing import Dict, List +from urllib.parse import urlparse from urllib.parse import urlparse import numpy as np @@ -142,10 +143,17 @@ def get_metrics(dataset_name: str) -> List[str]: f"Processing split={split}, replicate={replicate}, predictions_file={predictions_file}..." ) full_path = os.path.join(predictions_dir, predictions_file) - predicted_labels: List[Any] = get_predictions(full_path) - metric_results: Dict[str, float] = evaluate_replicate( - wilds_dataset, split, predicted_labels - ) + + # GlobalWheat's predictions are a list of dictionaries, so it has to be handle separately + if dataset_name == "wheat": + metric_results: Dict[str, float] = evaluate_replicate_for_globalwheat( + wilds_dataset, split, full_path + ) + else: + predicted_labels: torch.Tensor = get_predictions(full_path) + metric_results = evaluate_replicate( + wilds_dataset, split, predicted_labels + ) for metric in metrics: replicates_results[split][metric].append(metric_results[metric]) @@ -186,20 +194,23 @@ def evaluate_replicate( # Dataset will only be downloaded if it does not exist subset: WILDSSubset = dataset.get_subset(split) metadata: torch.Tensor = subset.metadata_array + true_labels = subset.y_array + if predicted_labels.shape != true_labels.shape: + predicted_labels.unsqueeze_(-1) + return dataset.eval(predicted_labels, true_labels, metadata)[0] - if type(dataset) == GlobalWheatDataset: - predicted_labels = predicted_labels - true_labels = list(itemgetter(*subset.indices)(subset.dataset.y_array)) - else: - true_labels: torch.Tensor = subset.y_array - if predicted_labels.shape != true_labels.shape: - predicted_labels.unsqueeze_(-1) - results_dict, results_str = dataset.eval(predicted_labels, true_labels, metadata) - return results_dict +def evaluate_replicate_for_globalwheat( + dataset: WILDSDataset, split: str, path_to_predictions: str +) -> Dict[str, float]: + predicted_labels = torch.load(path_to_predictions) + subset: WILDSSubset = dataset.get_subset(split) + metadata: torch.Tensor = subset.metadata_array + true_labels = list(itemgetter(*subset.indices)(subset.dataset.y_array)) + return dataset.eval(predicted_labels, true_labels, metadata)[0] -def get_predictions(path: str) -> torch.tensor: +def get_predictions(path: str) -> torch.Tensor: """ Extract out the predictions from the file at path. @@ -207,11 +218,9 @@ def get_predictions(path: str) -> torch.tensor: path (str): Path to the file that has the predicted labels. Can be a URL. Return: - List of predictions. + Tensor representing predictions """ - if path.endswith(".pth"): - return torch.load(path) - elif is_path_url(path): + if is_path_url(path): data = urllib.request.urlopen(path) else: file = open(path, mode="r") From 2cd48572a67074c4590c2041a8ad0f2ddc21a50b Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Sun, 11 Jul 2021 14:38:10 -0700 Subject: [PATCH 222/244] fix dataset check --- examples/evaluate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/evaluate.py b/examples/evaluate.py index 9a150e07..152887a1 100644 --- a/examples/evaluate.py +++ b/examples/evaluate.py @@ -145,7 +145,7 @@ def get_metrics(dataset_name: str) -> List[str]: full_path = os.path.join(predictions_dir, predictions_file) # GlobalWheat's predictions are a list of dictionaries, so it has to be handle separately - if dataset_name == "wheat": + if dataset_name == "globalwheat": metric_results: Dict[str, float] = evaluate_replicate_for_globalwheat( wilds_dataset, split, full_path ) From ef7e4944a8b138c494d787ea6c357cb1fde28302 Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Sun, 11 Jul 2021 14:43:33 -0700 Subject: [PATCH 223/244] use acc_avg or rxrx1 evaluation --- examples/evaluate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/evaluate.py b/examples/evaluate.py index 152887a1..3f2bde63 100644 --- a/examples/evaluate.py +++ b/examples/evaluate.py @@ -113,7 +113,7 @@ def get_metrics(dataset_name: str) -> List[str]: elif "globalwheat" == dataset_name: return ["detection_acc_avg_dom"] elif "rxrx1" == dataset_name: - return ["acc_avg", "acc_wg"] + return ["acc_avg"] else: raise ValueError(f"Invalid dataset: {dataset_name}") @@ -144,7 +144,7 @@ def get_metrics(dataset_name: str) -> List[str]: ) full_path = os.path.join(predictions_dir, predictions_file) - # GlobalWheat's predictions are a list of dictionaries, so it has to be handle separately + # GlobalWheat's predictions are a list of dictionaries, so it has to be handled separately if dataset_name == "globalwheat": metric_results: Dict[str, float] = evaluate_replicate_for_globalwheat( wilds_dataset, split, full_path From 507120613180be208887562b23f02d8a65e018a0 Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Sun, 11 Jul 2021 14:44:50 -0700 Subject: [PATCH 224/244] cleanup --- examples/evaluate.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/evaluate.py b/examples/evaluate.py index 3f2bde63..85ba26b7 100644 --- a/examples/evaluate.py +++ b/examples/evaluate.py @@ -6,14 +6,12 @@ from operator import itemgetter from typing import Dict, List from urllib.parse import urlparse -from urllib.parse import urlparse import numpy as np import torch from wilds import benchmark_datasets from wilds import get_dataset -from wilds.datasets.globalwheat_dataset import GlobalWheatDataset from wilds.datasets.wilds_dataset import WILDSDataset, WILDSSubset From 32767af3ddde35f592ef4796ff10edb1deb899e9 Mon Sep 17 00:00:00 2001 From: aikanor Date: Sun, 11 Jul 2021 15:49:19 -0700 Subject: [PATCH 225/244] cleanup --- wilds/datasets/encode_dataset.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/wilds/datasets/encode_dataset.py b/wilds/datasets/encode_dataset.py index b2301152..877a9380 100644 --- a/wilds/datasets/encode_dataset.py +++ b/wilds/datasets/encode_dataset.py @@ -301,10 +301,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' self._y_array = torch.tensor(np.load( self._data_dir + '/labels/{}/metadata_y.npy'.format(self._transcription_factor))) - # ~10% of the dataset has ambiguous labels - # i.e., we can't tell if there is a binding event or not. - # This typically happens at the flanking regions of peaks. - # For our purposes, we will ignore these ambiguous labels during training and eval. + # ~10% of the dataset has ambiguous labels, i.e., we can't tell if there is a binding event or not. This typically happens at the flanking regions of peaks. For our purposes, we will ignore these ambiguous labels during training and eval. self.y_array[self.y_array == 0.5] = float('nan') self._split_array = -1 * np.ones(self._metadata_df.shape[0]).astype(int) @@ -396,7 +393,8 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' self._metric = MultiTaskAveragePrecision() super().__init__(root_dir, download, split_scheme) - + + def get_input(self, idx, window_size=12800): """ Returns x for a given idx in metadata_array, which has been filtered to only take windows with the desired stride. @@ -417,7 +415,8 @@ def get_input(self, idx, window_size=12800): [seq_this, dnase_this] ).T) - + + def eval(self, y_pred, y_true, metadata): return self.standard_group_eval( self._metric, From 378ec0e67e8558c2ccc9dc73b16c40b7a01c0984 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Sun, 11 Jul 2021 19:22:49 -0700 Subject: [PATCH 226/244] Added GlobalWheat splits --- wilds/datasets/globalwheat_dataset.py | 74 ++++++++++++++++++--------- 1 file changed, 51 insertions(+), 23 deletions(-) diff --git a/wilds/datasets/globalwheat_dataset.py b/wilds/datasets/globalwheat_dataset.py index 1c4dbf49..3e44dc70 100644 --- a/wilds/datasets/globalwheat_dataset.py +++ b/wilds/datasets/globalwheat_dataset.py @@ -108,10 +108,13 @@ class GlobalWheatDataset(WILDSDataset): """ The GlobalWheat-WILDS wheat head localization dataset. - This is a modified version of the original Global Wheat Head Dataset 2021. + This is a modified version of the original Global Wheat Head Dataset 2021. Supported `split_scheme`: - 'official' + - 'official_with_subsampled_test' + - 'fixed-test' + - 'mixed-train' Input (x): 1024 x 1024 RGB images of wheat field canopy starting from anthesis (flowering) to ripening. Output (y): @@ -162,43 +165,68 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' self._is_classification = False self._y_size = None self._n_classes = 1 - self._split_scheme = split_scheme - # Get filenames + self._split_dict = { + 'train': 0, + 'val': 1, + 'test': 2, + } + self._split_names = { + 'train': 'Train', + 'val': 'Validation (OOD)', + 'test':'Test (OOD)', + } + + data_dfs = {} if split_scheme == "official": - train_data_df = pd.read_csv(self.root / f'official_train.csv') - val_data_df = pd.read_csv(self.root / f'official_val.csv') - test_data_df = pd.read_csv(self.root / f'official_test.csv') + data_dfs['train'] = pd.read_csv(self.root / f'official_train.csv') + data_dfs['val'] = pd.read_csv(self.root / f'official_val.csv') + data_dfs['test'] = pd.read_csv(self.root / f'official_test.csv') + data_dfs['id_val'] = pd.read_csv(self.root / f'fixed_train_val.csv') + data_dfs['id_test'] = pd.read_csv(self.root / f'fixed_train_test.csv') + self._split_dict = { + 'train': 0, + 'val': 1, + 'test': 2, + 'id_val': 3, + 'id_test': 4, + } + self._split_names = { + 'train': 'Train', + 'val': 'Validation (OOD)', + 'test':'Test (OOD)', + 'id_val': 'Validation (ID)', + 'id_test': 'Test (ID)' + } elif split_scheme == "official_with_subsampled_test": - train_data_df = pd.read_csv(self.root / f'official_train.csv') - val_data_df = pd.read_csv(self.root / f'official_val.csv') - test_data_df = pd.read_csv(self.root / f'fixed_test_test.csv') + data_dfs['train'] = pd.read_csv(self.root / f'official_train.csv') + data_dfs['val'] = pd.read_csv(self.root / f'official_val.csv') + data_dfs['test'] = pd.read_csv(self.root / f'fixed_test_test.csv') - elif split_scheme == "in-dist": - train_data_df = pd.read_csv(self.root / f'in_dist_train.csv') - val_data_df = pd.read_csv(self.root / f'official_val.csv') - test_data_df = pd.read_csv(self.root / f'in_dist_test.csv') + elif split_scheme == "fixed_test": + data_dfs['train'] = pd.read_csv(self.root / f'fixed_test_train.csv') + data_dfs['val'] = pd.read_csv(self.root / f'official_val.csv') + data_dfs['test'] = pd.read_csv(self.root / f'fixed_test_test.csv') - elif split_scheme == "fixed-train": - train_data_df = pd.read_csv(self.root / f'fixed_train_train.csv') - val_data_df = pd.read_csv(self.root / f'fixed_train_val.csv') - test_data_df = pd.read_csv(self.root / f'fixed_train_test.csv') + elif split_scheme == "mixed_train": + data_dfs['train'] = pd.read_csv(self.root / f'mixed_train_train.csv') + data_dfs['val'] = pd.read_csv(self.root / f'official_val.csv') + data_dfs['test'] = pd.read_csv(self.root / f'mixed_train_test.csv') - elif split_scheme == "fixed-test": - train_data_df = pd.read_csv(self.root / f'fixed_test_train.csv') - val_data_df = pd.read_csv(self.root / f'official_val.csv') - test_data_df = pd.read_csv(self.root / f'fixed_test_test.csv') + else: + raise ValueError(f'Split scheme {self.split_scheme} not recognized') self._image_array = [] self._split_array, self._y_array, self._metadata_array = [], [], [] - for i, df in enumerate([train_data_df, val_data_df, test_data_df]): + for split_name, split_idx in self._split_dict.items(): + df = data_dfs[split_name] self._image_array.extend(list(df['image_name'].values)) boxes_string = list(df['BoxesString'].values) all_boxes = [GlobalWheatDataset._decode_string(box_string) for box_string in boxes_string] - self._split_array.extend([i] * len(all_boxes)) + self._split_array.extend([split_idx] * len(all_boxes)) labels = [{ "boxes": torch.stack([ From 49d093fbf8a2a4174a7f80d94c33e511a392b161 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Sun, 11 Jul 2021 19:24:56 -0700 Subject: [PATCH 227/244] New GlobalWheat bundle URL --- wilds/datasets/globalwheat_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wilds/datasets/globalwheat_dataset.py b/wilds/datasets/globalwheat_dataset.py index 3e44dc70..0aecc3e3 100644 --- a/wilds/datasets/globalwheat_dataset.py +++ b/wilds/datasets/globalwheat_dataset.py @@ -151,8 +151,8 @@ class GlobalWheatDataset(WILDSDataset): _dataset_name = 'globalwheat' _versions_dict = { '1.0': { - 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x03b0584cb00d4ea987aa3269aa2fd2b4/contents/blob/', - 'compressed_size': 10_286_874_624} + 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x443fbcb18eeb4f80b5ea4a9f77795168/contents/blob/', + 'compressed_size': None} } def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): From 8ad9fa533c051278fd4433e69f58830e4f0ab523 Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Mon, 12 Jul 2021 12:08:13 -0700 Subject: [PATCH 228/244] update with best hyperparameter config for globalwheat --- examples/configs/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 014e91a0..02c89564 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -313,7 +313,7 @@ 'scheduler': None, 'batch_size': 4, 'lr': 1e-5, - 'weight_decay': 0, + 'weight_decay': 1e-3, 'n_epochs': 10, 'loader_kwargs': { 'num_workers': 1, From e628700f37ad7b3b118083bdba7ab371217ab94e Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Tue, 13 Jul 2021 09:51:25 -0700 Subject: [PATCH 229/244] GlobalWheat size --- wilds/datasets/globalwheat_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wilds/datasets/globalwheat_dataset.py b/wilds/datasets/globalwheat_dataset.py index 0aecc3e3..4d9a1400 100644 --- a/wilds/datasets/globalwheat_dataset.py +++ b/wilds/datasets/globalwheat_dataset.py @@ -152,7 +152,7 @@ class GlobalWheatDataset(WILDSDataset): _versions_dict = { '1.0': { 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x443fbcb18eeb4f80b5ea4a9f77795168/contents/blob/', - 'compressed_size': None} + 'compressed_size': 10_286_120_960} } def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): From 2f1a670ddb1b97f30921e7c1087ea5d9db74de79 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Tue, 13 Jul 2021 14:30:49 -0700 Subject: [PATCH 230/244] encode size --- wilds/datasets/encode_dataset.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/wilds/datasets/encode_dataset.py b/wilds/datasets/encode_dataset.py index 877a9380..f3009d72 100644 --- a/wilds/datasets/encode_dataset.py +++ b/wilds/datasets/encode_dataset.py @@ -102,6 +102,10 @@ class EncodeDataset(WILDSDataset): ENCODE dataset of transcription factor binding sites. This is a subset of the dataset from the ENCODE-DREAM in vivo Transcription Factor Binding Site Prediction Challenge. + Note: The first time this dataset is used, it will run some one-off preprocessing scripts that will take some additional time. + These scripts might cause a race condition if multiple jobs are started in parallel, + so we recommend running a single job the first time you use this dataset. + Input (x): 12800-base-pair regions of sequence with a quantified chromatin accessibility readout. @@ -121,7 +125,7 @@ class EncodeDataset(WILDSDataset): _versions_dict = { '1.0': { 'download_url': 'https://worksheets.codalab.org/rest/bundles/0x9c282b6e9082440f9dcd61bb605c1eab/contents/blob/', - 'compressed_size': None}} + 'compressed_size': 7_692_640_256}} def __init__(self, version=None, root_dir='data', download=False, split_scheme='official'): itime = time.time() @@ -393,8 +397,8 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' self._metric = MultiTaskAveragePrecision() super().__init__(root_dir, download, split_scheme) - - + + def get_input(self, idx, window_size=12800): """ Returns x for a given idx in metadata_array, which has been filtered to only take windows with the desired stride. @@ -415,8 +419,8 @@ def get_input(self, idx, window_size=12800): [seq_this, dnase_this] ).T) - - + + def eval(self, y_pred, y_true, metadata): return self.standard_group_eval( self._metric, From 3eae8f80151f8560d9bceb68e85258b36362f0ab Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Tue, 13 Jul 2021 14:37:54 -0700 Subject: [PATCH 231/244] accessibility edit --- .../encode/prep_accessibility.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/dataset_preprocessing/encode/prep_accessibility.py b/dataset_preprocessing/encode/prep_accessibility.py index 62228dba..218e4fc5 100644 --- a/dataset_preprocessing/encode/prep_accessibility.py +++ b/dataset_preprocessing/encode/prep_accessibility.py @@ -131,19 +131,6 @@ def dnase_normalize( if __name__ == '__main__': train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] - ch_train_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'K562', 'A549', 'GM12878'] - ch_val_celltype = ['HepG2'] - ch_test_celltype = ['liver'] - ref_celltypes = ch_train_celltypes - all_celltypes = ch_train_celltypes + ch_val_celltype + ch_test_celltype + all_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'K562', 'A549', 'GM12878', 'HepG2', 'liver'] for ct in all_celltypes: qn_sample_to_array([ct], input_chroms=train_chroms) - - """ - # Create normalized bigwigs for OOD validation split. - for ct in all_celltypes: - dnase_normalize(ct, ref_celltypes) - # Create normalized bigwig for ID validation split. - for ct in ch_test_celltype: - dnase_normalize(ct, ch_test_celltype, out_fname = 'norm_id') - """ From 4144ef56321bb9cfb4bc958f44facc3d027d4f57 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Tue, 13 Jul 2021 14:41:59 -0700 Subject: [PATCH 232/244] amazon hparams --- examples/configs/datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 2e74547f..164a8242 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -10,7 +10,7 @@ 'lr': 1e-5, 'weight_decay': 0.01, 'n_epochs': 3, - 'n_groups_per_batch': 4, + 'n_groups_per_batch': 2, 'irm_lambda': 1.0, 'coral_penalty_weight': 1.0, 'loader_kwargs': { @@ -113,7 +113,7 @@ 'val_metric_decreasing': False, 'optimizer': 'Adam', 'scheduler': 'MultiStepLR', - 'scheduler_kwargs': {'milestones':[3,6], 'gamma': 0.1}, + 'scheduler_kwargs': {'milestones':[3,6], 'gamma': 0.1}, 'batch_size': 128, 'lr': 1e-3, 'weight_decay': 1e-4, From 8ce4ce213ea94bae14f2fd9e398e4adc95cca0e6 Mon Sep 17 00:00:00 2001 From: aikanor Date: Tue, 13 Jul 2021 16:46:55 -0700 Subject: [PATCH 233/244] encode accessibility edits --- dataset_preprocessing/encode/README.md | 8 +- .../encode/prep_accessibility.py | 81 ------------------- 2 files changed, 5 insertions(+), 84 deletions(-) diff --git a/dataset_preprocessing/encode/README.md b/dataset_preprocessing/encode/README.md index 598177bd..db1ef122 100644 --- a/dataset_preprocessing/encode/README.md +++ b/dataset_preprocessing/encode/README.md @@ -5,13 +5,15 @@ #### Instructions to create Codalab bundle +All file paths are taken to be relative to `` in the code. + 1. Download the human genome sequence (hg19 assembly) in FASTA format from http://hgdownload.cse.ucsc.edu/goldenpath/hg19/bigZips/hg19.fa.gz and extract it into `SEQUENCE_PATH`. -2. Run `python prep_sequence.py --seq_path SEQUENCE_PATH --output_dir OUTPUT_DIR` to write the fasta file found in `SEQUENCE_PATH` to a numpy array archive in `OUTPUT_DIR`. +2. Run `python prep_sequence.py --seq_path SEQUENCE_PATH --output_dir OUTPUT_DIR` to write the fasta file found in `SEQUENCE_PATH` to a numpy array archive in `OUTPUT_DIR`. `OUTPUT_DIR` is taken to be `sequence.npz` in the code. -3. Download the DNase accessibility data. This consists of whole-genome DNase files in bigwig format from https://guanfiles.dcmb.med.umich.edu/Leopard/dnase_bigwig/. These are saved with filename `DNASE..fc.signal.bigwig`. +3. Download the DNase accessibility data. This consists of whole-genome DNase files in bigwig format from https://guanfiles.dcmb.med.umich.edu/Leopard/dnase_bigwig/. These are saved with filename `DNASE..fc.signal.bigwig` in the code. -4. Run `python prep_accessibility.py`. +4. Run `python prep_accessibility.py`. This writes quantile-normalized samples of each bigwig file to `qn..npy`. 5. Download the labels from the challenge into a label directory `labels/` created for this purpose: - The training chromosome labels for the challenge's training cell types from https://www.synapse.org/#!Synapse:syn7413983 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn7415202 for the TF MAX, , downloaded as MAX.train.labels.tsv.gz ). diff --git a/dataset_preprocessing/encode/prep_accessibility.py b/dataset_preprocessing/encode/prep_accessibility.py index 218e4fc5..eb8a565e 100644 --- a/dataset_preprocessing/encode/prep_accessibility.py +++ b/dataset_preprocessing/encode/prep_accessibility.py @@ -48,87 +48,6 @@ def qn_sample_to_array( np.save(data_pfx + "qn.{}.npy".format('.'.join(input_celltypes)), sample) -# quantile normalization via numpy inter/extra-polation -def anchor(input_data, sample, ref): # input 1d array - sample.sort() - ref.sort() - # 0. create the mapping function - index = np.array(np.where(np.diff(sample) != 0)) + 1 - index = index.flatten() - x = np.concatenate((np.zeros(1), sample[index])) # domain - y = np.zeros(len(x)) # codomain - for i in np.arange(0,len(index)-1, 1): - start = index[i] - end = index[i+1] - y[i+1] = np.mean(ref[start:end]) - i += 1 - start = index[i] - end = len(ref) - y[i+1] = np.mean(ref[start:end]) - # 1. interpolate - output = np.interp(input_data, x, y) - # 2. extrapolate - degree = 1 # degree of the fitting polynomial - num = 10 # number of positions for extrapolate - f1 = np.poly1d(np.polyfit(sample[-num:],ref[-num:],degree)) - output[input_data > sample[-1]] = f1(input_data[input_data > sample[-1]]) - return output - - -def wrap_anchor(signal, sample, ref): - ## 1.format as bigwig first - x = signal - z = np.concatenate(([0],x,[0])) # pad two zeroes - # find boundary - starts = np.where(np.diff(z) != 0)[0] - ends = starts[1:] - starts = starts[:-1] - vals = x[starts] - if starts[0] != 0: - ends = np.concatenate(([starts[0]],ends)) - starts = np.concatenate(([0],starts)) - vals = np.concatenate(([0],vals)) - if ends[-1] != len(signal): - starts = np.concatenate((starts,[ends[-1]])) - ends = np.concatenate((ends,[len(signal)])) - vals = np.concatenate((vals,[0])) - - ## 2.then quantile normalization - vals_anchored = anchor(vals, sample, ref) - return vals_anchored, starts, ends - - -def dnase_normalize( - input_bw_celltype, - ref_celltypes, - out_fname = 'norm', - data_pfx = '/users/abalsubr/wilds/examples/data/encode_v1.0/' -): - itime = time.time() - sample = np.load(data_pfx + "qn.{}.npy".format(input_bw_celltype)) - ref = np.zeros(len(sample)) - for ct in ref_celltypes: - ref += (1.0/len(ref_celltypes))*np.load(data_pfx + "qn.{}.npy".format(ct)) - - chromsizes_list = [(k, v) for k, v in chrom_sizes.items()] - bw_output = pyBigWig.open(data_pfx + 'DNase.{}.{}.bigwig'.format( - input_bw_celltype, out_fname), 'w') - bw_output.addHeader(chromsizes_list) - - for the_chr in chrom_sizes: - signal = np.zeros(chrom_sizes[the_chr]) - bw = pyBigWig.open(data_pfx + 'DNASE.{}.fc.signal.bigwig'.format(input_bw_celltype)) - signal += np.nan_to_num(np.array(bw.values(the_chr, 0, chrom_sizes[the_chr]))) - bw.close() - vals_anchored, starts, ends = wrap_anchor(signal, sample, ref) - # write normalized dnase file. - chroms = np.array([the_chr] * len(vals_anchored)) - bw_output.addEntries(chroms, starts, ends=ends, values=vals_anchored) - print(input_bw_celltype, the_chr, time.time() - itime) - - bw_output.close() - - if __name__ == '__main__': train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] all_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'K562', 'A549', 'GM12878', 'HepG2', 'liver'] From 6b5084d8c2142250cdf3b34914216cc2a9a90cfb Mon Sep 17 00:00:00 2001 From: aikanor Date: Wed, 14 Jul 2021 01:15:56 -0700 Subject: [PATCH 234/244] encode accessibility edits --- dataset_preprocessing/encode/README.md | 12 ++++++------ .../encode/prep_accessibility.py | 17 ++++++++--------- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/dataset_preprocessing/encode/README.md b/dataset_preprocessing/encode/README.md index db1ef122..66053fcc 100644 --- a/dataset_preprocessing/encode/README.md +++ b/dataset_preprocessing/encode/README.md @@ -5,18 +5,18 @@ #### Instructions to create Codalab bundle -All file paths are taken to be relative to `` in the code. +Here are instructions to reproduce the Codalab bundle, in a directory path `BUNDLE_ROOT_DIRECTORY`. 1. Download the human genome sequence (hg19 assembly) in FASTA format from http://hgdownload.cse.ucsc.edu/goldenpath/hg19/bigZips/hg19.fa.gz and extract it into `SEQUENCE_PATH`. -2. Run `python prep_sequence.py --seq_path SEQUENCE_PATH --output_dir OUTPUT_DIR` to write the fasta file found in `SEQUENCE_PATH` to a numpy array archive in `OUTPUT_DIR`. `OUTPUT_DIR` is taken to be `sequence.npz` in the code. +2. Run `python prep_sequence.py --seq_path SEQUENCE_PATH --output_dir OUTPUT_DIR` to write the fasta file found in `SEQUENCE_PATH` to a numpy array archive in `OUTPUT_PATH`. (The dataset loader assumes `OUTPUT_PATH` to be `/sequence.npz`.) -3. Download the DNase accessibility data. This consists of whole-genome DNase files in bigwig format from https://guanfiles.dcmb.med.umich.edu/Leopard/dnase_bigwig/. These are saved with filename `DNASE..fc.signal.bigwig` in the code. +3. Download the DNase accessibility data. This consists of whole-genome DNase files in bigwig format from https://guanfiles.dcmb.med.umich.edu/Leopard/dnase_bigwig/. Save these to filenames `/DNASE..fc.signal.bigwig` in the code. -4. Run `python prep_accessibility.py`. This writes quantile-normalized samples of each bigwig file to `qn..npy`. +4. Run `python prep_accessibility.py`. This writes samples of each bigwig file to `/qn..npy`. These are used at runtime when the dataset loader is initialized, to perform quantile normalization on the DNase accessibility signals. -5. Download the labels from the challenge into a label directory `labels/` created for this purpose: - - The training chromosome labels for the challenge's training cell types from https://www.synapse.org/#!Synapse:syn7413983 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn7415202 for the TF MAX, , downloaded as MAX.train.labels.tsv.gz ). +5. Download the labels from the challenge into a label directory `/labels/` created for this purpose: + - The training chromosome labels for the challenge's training cell types from https://www.synapse.org/#!Synapse:syn7413983 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn7415202 for the TF MAX, downloaded as MAX.train.labels.tsv.gz ). - The training chromosome labels for the challenge's evaluation cell type (liver) from https://www.synapse.org/#!Synapse:syn8077511 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn8077648 for the TF MAX, downloaded as MAX.train_wc.labels.tsv.gz ). - The validation chromosome labels for the challenge's training cell types from https://www.synapse.org/#!Synapse:syn8441154 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn8442103 for the TF MAX, downloaded as MAX.val.labels.tsv.gz ). - The validation chromosome labels for the challenge's evaluation cell type (liver) from https://www.synapse.org/#!Synapse:syn8442975 for the relevant transcription factor ( https://www.synapse.org/#!Synapse:syn8443021 for the TF MAX, downloaded as MAX.test.labels.tsv.gz ). diff --git a/dataset_preprocessing/encode/prep_accessibility.py b/dataset_preprocessing/encode/prep_accessibility.py index eb8a565e..c4831bd8 100644 --- a/dataset_preprocessing/encode/prep_accessibility.py +++ b/dataset_preprocessing/encode/prep_accessibility.py @@ -4,7 +4,7 @@ import numpy as np import pyBigWig -# Human chromosomes in hg19 +# Human chromosomes in hg19, and their sizes in bp chrom_sizes = {'chr1': 249250621, 'chr10': 135534747, 'chr11': 135006516, 'chr12': 133851895, 'chr13': 115169878, 'chr14': 107349540, 'chr15': 102531392, 'chr16': 90354753, 'chr17': 81195210, 'chr18': 78077248, 'chr19': 59128983, 'chr2': 243199373, 'chr20': 63025520, 'chr21': 48129895, 'chr22': 51304566, 'chr3': 198022430, 'chr4': 191154276, 'chr5': 180915260, 'chr6': 171115067, 'chr7': 159138663, 'chr8': 146364022, 'chr9': 141213431, 'chrX': 155270560} @@ -14,18 +14,20 @@ def qn_sample_to_array( subsampling_ratio=1000, data_pfx = '/users/abalsubr/wilds/examples/data/encode_v1.0/' ): - itime = time.time() + """ + Compute and write distribution of DNase bigwigs corresponding to input celltypes. + """ if input_chroms is None: input_chroms = chrom_sizes.keys() qn_chrom_sizes = { k: chrom_sizes[k] for k in input_chroms } - # chromosome-specific subsampling seeds + # Initialize chromosome-specific seeds for subsampling chr_to_seed = {} i = 0 for the_chr in qn_chrom_sizes: chr_to_seed[the_chr] = i i += 1 - # subsampling; multiple replicates are added + # subsampling sample_len = np.ceil(np.array(list(qn_chrom_sizes.values()))/subsampling_ratio).astype(int) sample = np.zeros(sum(sample_len)) start = 0 @@ -40,16 +42,13 @@ def qn_sample_to_array( sample[start:(start+sample_len[j])] += (1.0/len(input_celltypes))*signal[index] start += sample_len[j] j += 1 - print(the_chr, ct, time.time() - itime) - - if np.any(np.isnan(sample)): - print('wtf! sample contains nan!') + print(the_chr, ct) sample.sort() np.save(data_pfx + "qn.{}.npy".format('.'.join(input_celltypes)), sample) if __name__ == '__main__': train_chroms = ['chr3', 'chr4', 'chr5', 'chr6', 'chr7', 'chr10', 'chr12', 'chr13', 'chr14', 'chr15', 'chr16', 'chr17', 'chr18', 'chr19', 'chr20', 'chr22', 'chrX'] - all_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'K562', 'A549', 'GM12878', 'HepG2', 'liver'] + all_celltypes = ['H1-hESC', 'HCT116', 'HeLa-S3', 'K562', 'A549', 'GM12878', 'MCF-7', 'HepG2', 'liver'] for ct in all_celltypes: qn_sample_to_array([ct], input_chroms=train_chroms) From 50e3b982aff42281c6d426b5182f0c640c2d0c54 Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Wed, 14 Jul 2021 17:42:15 -0700 Subject: [PATCH 235/244] update README --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index ce05cae2..6b922b5a 100644 --- a/README.md +++ b/README.md @@ -103,7 +103,9 @@ These are the sizes of each of our datasets, as well as their approximate time t |-----------------|----------|--------------------|-------------------|-------------------------| | iwildcam | Image | 11 | 25 | 7 | | camelyon17 | Image | 10 | 15 | 2 | +| rxrx1 | Image | 6.9 | 7.4 | 11 | | ogb-molpcba | Graph | 0.04 | 2 | 15 | +| globalwheat | Image | 9.7 | 10.4 | 2 | | civilcomments | Text | 0.1 | 0.3 | 4.5 | | fmow | Image | 50 | 55 | 6 | | poverty | Image | 12 | 14 | 5 | From 8018dd04f0a7fc7fdebf483d641608b7c30dfa40 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Wed, 14 Jul 2021 21:17:41 -0700 Subject: [PATCH 236/244] Fix mixed-to-test Test (ID) split for PovertyMap + use split_scheme to toggle oracle_training_set for PovertyMap and FMoW --- examples/configs/datasets.py | 4 +--- wilds/datasets/fmow_dataset.py | 19 ++++++++++++------- wilds/datasets/poverty_dataset.py | 22 ++++++++++++---------- 3 files changed, 25 insertions(+), 20 deletions(-) diff --git a/examples/configs/datasets.py b/examples/configs/datasets.py index 164a8242..e8bbe0d2 100644 --- a/examples/configs/datasets.py +++ b/examples/configs/datasets.py @@ -125,8 +125,7 @@ }, 'fmow': { 'split_scheme': 'official', - 'dataset_kwargs': { - 'oracle_training_set': False, + 'dataset_kwargs': { 'seed': 111, 'use_ood_val': True }, @@ -218,7 +217,6 @@ 'dataset_kwargs': { 'no_nl': False, 'fold': 'A', - 'oracle_training_set': False, 'use_ood_val': True }, 'model': 'resnet18_ms', diff --git a/wilds/datasets/fmow_dataset.py b/wilds/datasets/fmow_dataset.py index f8e85b9a..21b9e099 100644 --- a/wilds/datasets/fmow_dataset.py +++ b/wilds/datasets/fmow_dataset.py @@ -28,9 +28,10 @@ class FMoWDataset(WILDSDataset): The Functional Map of the World land use / building classification dataset. This is a processed version of the Functional Map of the World dataset originally sourced from https://github.com/fMoW/dataset. - Support `split_scheme` - 'official': official split, which is equivalent to 'time_after_2016' - `time_after_{YEAR}` for YEAR between 2002--2018 + Supported `split_scheme`: + - 'official': official split, which is equivalent to 'time_after_2016' + - 'mixed-to-test' + - 'time_after_{YEAR}' for YEAR between 2002--2018 Input (x): 224 x 224 x 3 RGB satellite image. @@ -63,16 +64,20 @@ class FMoWDataset(WILDSDataset): 'compressed_size': 53_893_324_800} } - def __init__(self, version=None, root_dir='data', download=False, split_scheme='official', oracle_training_set=False, seed=111, use_ood_val=True): + def __init__(self, version=None, root_dir='data', download=False, split_scheme='official', seed=111, use_ood_val=True): self._version = version self._data_dir = self.initialize_data_dir(root_dir, download) self._split_dict = {'train': 0, 'id_val': 1, 'id_test': 2, 'val': 3, 'test': 4} self._split_names = {'train': 'Train', 'id_val': 'ID Val', 'id_test': 'ID Test', 'val': 'OOD Val', 'test': 'OOD Test'} - if split_scheme=='official': - split_scheme='time_after_2016' + + self.oracle_training_set = False + if split_scheme == 'official': + split_scheme = 'time_after_2016' + elif split_scheme == 'mixed-to-test': + split_scheme = 'time_after_2016' + self.oracle_training_set = True self._split_scheme = split_scheme - self.oracle_training_set = oracle_training_set self.root = Path(self._data_dir) self.seed = int(seed) diff --git a/wilds/datasets/poverty_dataset.py b/wilds/datasets/poverty_dataset.py index 7b062002..8fcb5245 100644 --- a/wilds/datasets/poverty_dataset.py +++ b/wilds/datasets/poverty_dataset.py @@ -109,7 +109,8 @@ class PovertyMapDataset(WILDSDataset): and processed DHS survey metadata obtained from https://github.com/sustainlab-group/africa_poverty and originally from `https://dhsprogram.com/data/available-datasets.cfm`. Supported `split_scheme`: - 'official' and `countries`, which are equivalent + - 'official' and `countries`, which are equivalent + - 'mixed-to-test' Input (x): 224 x 224 x 8 satellite image, with 7 channels from LandSat and 1 nighttime light channel from DMSP/VIIRS. Already mean/std normalized. @@ -149,7 +150,7 @@ class PovertyMapDataset(WILDSDataset): def __init__(self, version=None, root_dir='data', download=False, split_scheme='official', - no_nl=False, fold='A', oracle_training_set=False, + no_nl=False, fold='A', use_ood_val=True, cache_size=100): self._version = version @@ -158,13 +159,16 @@ def __init__(self, version=None, root_dir='data', download=False, self._split_dict = {'train': 0, 'id_val': 1, 'id_test': 2, 'val': 3, 'test': 4} self._split_names = {'train': 'Train', 'id_val': 'ID Val', 'id_test': 'ID Test', 'val': 'OOD Val', 'test': 'OOD Test'} - if split_scheme=='official': + if split_scheme == 'official': split_scheme = 'countries' - self._split_scheme = split_scheme - if self._split_scheme != 'countries': - raise ValueError("Split scheme not recognized") - self.oracle_training_set = oracle_training_set + if split_scheme == 'mixed-to-test': + self.oracle_training_set = True + elif split_scheme in ['official', 'countries']: + self.oracle_training_set = False + else: + raise ValueError("Split scheme not recognized") + self._split_scheme = split_scheme self.no_nl = no_nl if fold not in {'A', 'B', 'C', 'D', 'E'}: @@ -191,11 +195,9 @@ def __init__(self, version=None, root_dir='data', download=False, else: idxs = idxs_id num_eval = 2000 - # if oracle, do 50-50 split between OOD and ID + # if oracle, sample from all countries if split == 'train' and self.oracle_training_set: idxs = subsample_idxs(incountry_folds_split, num=len(idxs_id), seed=ord(fold))[num_eval:] - elif split != 'train' and self.oracle_training_set: - eval_idxs = subsample_idxs(incountry_folds_split, num=len(idxs_id), seed=ord(fold))[:num_eval] elif split == 'train': idxs = subsample_idxs(idxs, take_rest=True, num=num_eval, seed=ord(fold)) else: From a867a7c04c84bc8a06346744e9fddd87bd8d38cc Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Wed, 14 Jul 2021 21:38:11 -0700 Subject: [PATCH 237/244] Update in-dist terminology to match updated text --- wilds/datasets/camelyon17_dataset.py | 7 ++++--- wilds/datasets/encode_dataset.py | 8 ++++++-- wilds/datasets/globalwheat_dataset.py | 8 ++++---- wilds/datasets/iwildcam_dataset.py | 2 ++ wilds/datasets/ogbmolpcba_dataset.py | 4 ++-- wilds/datasets/py150_dataset.py | 4 +++- wilds/datasets/rxrx1_dataset.py | 7 ++++--- wilds/datasets/wilds_dataset.py | 2 +- 8 files changed, 26 insertions(+), 16 deletions(-) diff --git a/wilds/datasets/camelyon17_dataset.py b/wilds/datasets/camelyon17_dataset.py index ff6e6c63..b84dd6d2 100644 --- a/wilds/datasets/camelyon17_dataset.py +++ b/wilds/datasets/camelyon17_dataset.py @@ -13,7 +13,8 @@ class Camelyon17Dataset(WILDSDataset): This is a modified version of the original CAMELYON17 dataset. Supported `split_scheme`: - 'official' or 'in-dist' + - 'official' + - 'mixed-to-test' Input (x): 96x96 image patches extracted from histopathology slides. @@ -102,8 +103,8 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' self._split_scheme = split_scheme if self._split_scheme == 'official': pass - elif self._split_scheme == 'in-dist': - # For the in-distribution oracle, + elif self._split_scheme == 'mixed-to-test': + # For the mixed-to-test setting, # we move slide 23 (corresponding to patient 042, node 3 in the original dataset) # from the test set to the training set slide_mask = (self._metadata_df['slide'] == 23) diff --git a/wilds/datasets/encode_dataset.py b/wilds/datasets/encode_dataset.py index f3009d72..afd2a871 100644 --- a/wilds/datasets/encode_dataset.py +++ b/wilds/datasets/encode_dataset.py @@ -104,7 +104,11 @@ class EncodeDataset(WILDSDataset): Note: The first time this dataset is used, it will run some one-off preprocessing scripts that will take some additional time. These scripts might cause a race condition if multiple jobs are started in parallel, - so we recommend running a single job the first time you use this dataset. + so we recommend running a single job the first time you use this dataset. + + Supported `split_scheme`: + - 'official' + - 'test-to-test' Input (x): 12800-base-pair regions of sequence with a quantified chromatin accessibility readout. @@ -197,7 +201,7 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' 'id_val': 'Validation (ID)', 'id_test': 'Test (ID)', } - elif self._split_scheme == 'in-dist': + elif self._split_scheme == 'test-to-test': splits = { 'train': { 'chroms': train_chroms, diff --git a/wilds/datasets/globalwheat_dataset.py b/wilds/datasets/globalwheat_dataset.py index 4d9a1400..7e3eefa6 100644 --- a/wilds/datasets/globalwheat_dataset.py +++ b/wilds/datasets/globalwheat_dataset.py @@ -113,8 +113,8 @@ class GlobalWheatDataset(WILDSDataset): Supported `split_scheme`: - 'official' - 'official_with_subsampled_test' - - 'fixed-test' - - 'mixed-train' + - 'test-to-test' + - 'mixed-to-test' Input (x): 1024 x 1024 RGB images of wheat field canopy starting from anthesis (flowering) to ripening. Output (y): @@ -205,12 +205,12 @@ def __init__(self, version=None, root_dir='data', download=False, split_scheme=' data_dfs['val'] = pd.read_csv(self.root / f'official_val.csv') data_dfs['test'] = pd.read_csv(self.root / f'fixed_test_test.csv') - elif split_scheme == "fixed_test": + elif split_scheme == "test-to-test": data_dfs['train'] = pd.read_csv(self.root / f'fixed_test_train.csv') data_dfs['val'] = pd.read_csv(self.root / f'official_val.csv') data_dfs['test'] = pd.read_csv(self.root / f'fixed_test_test.csv') - elif split_scheme == "mixed_train": + elif split_scheme == "mixed-to-test": data_dfs['train'] = pd.read_csv(self.root / f'mixed_train_train.csv') data_dfs['val'] = pd.read_csv(self.root / f'official_val.csv') data_dfs['test'] = pd.read_csv(self.root / f'mixed_train_test.csv') diff --git a/wilds/datasets/iwildcam_dataset.py b/wilds/datasets/iwildcam_dataset.py index d98bbb72..21b64bb1 100644 --- a/wilds/datasets/iwildcam_dataset.py +++ b/wilds/datasets/iwildcam_dataset.py @@ -17,6 +17,8 @@ class IWildCamDataset(WILDSDataset): """ The iWildCam2020 dataset. This is a modified version of the original iWildCam2020 competition dataset. + Supported `split_scheme`: + - 'official' Input (x): RGB images from camera traps Label (y): diff --git a/wilds/datasets/ogbmolpcba_dataset.py b/wilds/datasets/ogbmolpcba_dataset.py index 21e2ace5..84ecb818 100644 --- a/wilds/datasets/ogbmolpcba_dataset.py +++ b/wilds/datasets/ogbmolpcba_dataset.py @@ -13,7 +13,7 @@ class OGBPCBADataset(WILDSDataset): This dataset is directly adopted from Open Graph Benchmark, and originally curated by MoleculeNet. Supported `split_scheme`: - 'official' or 'scaffold', which are equivalent + - 'official' or 'scaffold', which are equivalent Input (x): Molecular graphs represented as Pytorch Geometric data objects @@ -108,7 +108,7 @@ def eval(self, y_pred, y_true, metadata, prediction_fn=None): - y_pred (FloatTensor): Binary logits from a model - y_true (LongTensor): Ground-truth labels - metadata (Tensor): Metadata - - prediction_fn (function): A function that turns y_pred into predicted labels. + - prediction_fn (function): A function that turns y_pred into predicted labels. Only None is supported because OGB Evaluators accept binary logits Output: - results (dictionary): Dictionary of evaluation metrics diff --git a/wilds/datasets/py150_dataset.py b/wilds/datasets/py150_dataset.py index e821c632..1aade110 100644 --- a/wilds/datasets/py150_dataset.py +++ b/wilds/datasets/py150_dataset.py @@ -13,7 +13,9 @@ class Py150Dataset(WILDSDataset): """ The Py150 dataset. - This is a modified version of the original Py150 dataset. + This is a modified version of the original Py150 dataset. + Supported `split_scheme`: + - 'official' Input (x): A Python code snippet (a sequence of tokens) Label (y): diff --git a/wilds/datasets/rxrx1_dataset.py b/wilds/datasets/rxrx1_dataset.py index bc728dce..866303e0 100644 --- a/wilds/datasets/rxrx1_dataset.py +++ b/wilds/datasets/rxrx1_dataset.py @@ -18,7 +18,8 @@ class RxRx1Dataset(WILDSDataset): This is a modified version of the original RxRx1 dataset. Supported `split_scheme`: - 'official' or 'in-dist' + - 'official' + - 'mixed-to-test' Input (x): 3-channel fluorescent microscopy images of cells @@ -65,7 +66,7 @@ def __init__(self, version=None, root_dir='data', download=False, self._version = version self._split_scheme = split_scheme - if self._split_scheme not in ['official', 'in-dist']: + if self._split_scheme not in ['official', 'mixed-to-test']: raise ValueError(f'Split scheme {self._split_scheme} not recognized') # path @@ -98,7 +99,7 @@ def __init__(self, version=None, root_dir='data', download=False, mask = ((df.dataset == 'train') & (df.site == 2)).values self._split_array[mask] = self.split_dict['id_test'] - elif split_scheme == 'in-dist': + elif split_scheme == 'mixed-to-test': # Training: 33 experiments total, 1 site per experiment (site 1) # = 19 experiments from the orig training set (site 1) # + 14 experiments from the orig test set (site 1) diff --git a/wilds/datasets/wilds_dataset.py b/wilds/datasets/wilds_dataset.py index 8812f957..48add021 100644 --- a/wilds/datasets/wilds_dataset.py +++ b/wilds/datasets/wilds_dataset.py @@ -183,7 +183,7 @@ def collate(self): def split_scheme(self): """ A string identifier of how the split is constructed, - e.g., 'standard', 'in-dist', 'user', etc. + e.g., 'standard', 'mixed-to-test', 'user', etc. """ return self._split_scheme From 892bd6f45b252b5fdb6fcb6e1833f37ea8d8520f Mon Sep 17 00:00:00 2001 From: kohpangwei Date: Fri, 16 Jul 2021 20:32:11 -0700 Subject: [PATCH 238/244] Update README.md --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 6b922b5a..4c530a34 100644 --- a/README.md +++ b/README.md @@ -103,13 +103,13 @@ These are the sizes of each of our datasets, as well as their approximate time t |-----------------|----------|--------------------|-------------------|-------------------------| | iwildcam | Image | 11 | 25 | 7 | | camelyon17 | Image | 10 | 15 | 2 | -| rxrx1 | Image | 6.9 | 7.4 | 11 | +| rxrx1 | Image | 7 | 7 | 11 | | ogb-molpcba | Graph | 0.04 | 2 | 15 | -| globalwheat | Image | 9.7 | 10.4 | 2 | +| globalwheat | Image | 10 | 10 | 2 | | civilcomments | Text | 0.1 | 0.3 | 4.5 | | fmow | Image | 50 | 55 | 6 | | poverty | Image | 12 | 14 | 5 | -| amazon | Text | 6.6 | 7 | 5 | +| amazon | Text | 7 | 7 | 5 | | py150 | Text | 0.1 | 0.8 | 9.5 | While the `camelyon17` dataset is small and fast to train on, we advise against using it as the only dataset to prototype methods on, as the test performance of models trained on this dataset tend to exhibit a large degree of variability over random seeds. From e016f84c61c01ba9e4701744059050205a91ba98 Mon Sep 17 00:00:00 2001 From: Shiori Sagawa Date: Sat, 17 Jul 2021 15:51:57 -0700 Subject: [PATCH 239/244] update README.md --- README.md | 44 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 35 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 4c530a34..29adcd8a 100644 --- a/README.md +++ b/README.md @@ -131,13 +131,14 @@ Any dataset not in `` will be downloaded to ``. ### Reproducibility We have an [executable version](https://wilds.stanford.edu/codalab) of our paper on CodaLab that contains the exact commands, code, and data for the experiments reported in our paper, which rely on these scripts. Trained model weights for all datasets can also be found there. - +All configurations and hyperparameters can also be found in the `examples/configs` folder of this repo, and dataset-specific parameters are in `examples/configs/datasets.py`. ## Using the WILDS package -### Data loading +### Data The WILDS package provides a simple, standardized interface for all datasets in the benchmark. This short Python snippet covers all of the steps of getting started with a WILDS dataset, including dataset download and initialization, accessing various splits, and preparing a user-customizable data loader. +We discuss data loading in more detail in [#Data loading](#data-loading). ```py >>> from wilds import get_dataset @@ -163,10 +164,10 @@ This short Python snippet covers all of the steps of getting started with a WILD The `metadata` contains information like the domain identity, e.g., which camera a photo was taken from, or which hospital the patient's data came from, etc. ### Domain information -To allow algorithms to leverage domain annotations as well as other -groupings over the available metadata, the WILDS package provides `Grouper` objects. -These `Grouper` objects extract group annotations from metadata, allowing users to -specify the grouping scheme in a flexible fashion. +To allow algorithms to leverage domain annotations as well as other groupings over the available metadata, the WILDS package provides `Grouper` objects. +These `Grouper` objects are helper objects that extract group annotations from metadata, allowing users to specify the grouping scheme in a flexible fashion. +They are used to initialize group-aware data loaders and to implement algorithms that rely on domain annotations (e.g., Group DRO). +In the following code snippet, we initialize and use a `Grouper` that extracts the domain annotations on the iWildCam dataset, where the domain is location. ```py >>> from wilds.common.grouper import CombinatorialGrouper @@ -181,9 +182,20 @@ specify the grouping scheme in a flexible fashion. ... ... ``` -The `Grouper` can be used to prepare a group-aware data loader that, for each minibatch, first samples a specified number of groups, then samples examples from those groups. -This allows our data loaders to accommodate a wide array of training algorithms, -some of which require specific data loading schemes. +### Data loading + +For training, the WILDS package provides two types of data loaders. +The standard data loader samples examples uniformly at random from the training set, and are used for algorithms such as empirical risk minimization (ERM). +```py +>>> from wilds.common.data_loaders import get_train_loader + +# Prepare the standard data loader +>>> train_loader = get_train_loader('standard', train_data, batch_size=16) +``` + +To support other algorithms that rely on specific data loading schemes, we also provide the group data loader. +In each minibatch, it first samples a specified number of groups uniformly at random (and therefore upweights minority groups), and then samples a fixed number of examples from each of those groups. +We initialize group loaders as follows, using `Grouper` that specifies the grouping scheme. ```py # Prepare a group data loader that samples from user-specified groups @@ -193,6 +205,20 @@ some of which require specific data loading schemes. ... batch_size=16) ``` +Lastly, we also provide a data loader for evaluation, which loads examples without shuffling unlike the training loaders. + +```py +>>> from wilds.common.data_loaders import get_eval_loader + +# Get the test set +>>> test_data = dataset.get_subset('test', +... transform=transforms.Compose([transforms.Resize((224,224)), +... transforms.ToTensor()])) + +# Prepare the evaluation data loader +>>> test_loader = get_eval_loader('standard', test_data, batch_size=16) +``` + ### Evaluators The WILDS package standardizes and automates evaluation for each dataset. From 019c70a4a1f04a08d3975eaecae4529f0c1bd2e3 Mon Sep 17 00:00:00 2001 From: Shiori Sagawa Date: Sat, 17 Jul 2021 16:04:24 -0700 Subject: [PATCH 240/244] Update README.md --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 29adcd8a..c4108452 100644 --- a/README.md +++ b/README.md @@ -161,12 +161,12 @@ We discuss data loading in more detail in [#Data loading](#data-loading). ... ... ``` -The `metadata` contains information like the domain identity, e.g., which camera a photo was taken from, or which hospital the patient's data came from, etc. +The `metadata` contains information like the domain identity, e.g., which camera a photo was taken from, or which hospital the patient's data came from, etc., as well as other metadata. ### Domain information To allow algorithms to leverage domain annotations as well as other groupings over the available metadata, the WILDS package provides `Grouper` objects. These `Grouper` objects are helper objects that extract group annotations from metadata, allowing users to specify the grouping scheme in a flexible fashion. -They are used to initialize group-aware data loaders and to implement algorithms that rely on domain annotations (e.g., Group DRO). +They are used to initialize group-aware data loaders (as discussed in [#Data loading](#data-loading)) and to implement algorithms that rely on domain annotations (e.g., Group DRO). In the following code snippet, we initialize and use a `Grouper` that extracts the domain annotations on the iWildCam dataset, where the domain is location. ```py @@ -185,7 +185,7 @@ In the following code snippet, we initialize and use a `Grouper` that extracts t ### Data loading For training, the WILDS package provides two types of data loaders. -The standard data loader samples examples uniformly at random from the training set, and are used for algorithms such as empirical risk minimization (ERM). +The standard data loader shuffles examples in the training set, and is used for the standard approach of empirical risk minimization (ERM), where we minimize the average loss. ```py >>> from wilds.common.data_loaders import get_train_loader @@ -194,7 +194,7 @@ The standard data loader samples examples uniformly at random from the training ``` To support other algorithms that rely on specific data loading schemes, we also provide the group data loader. -In each minibatch, it first samples a specified number of groups uniformly at random (and therefore upweights minority groups), and then samples a fixed number of examples from each of those groups. +In each minibatch, the group loader first samples a specified number of groups uniformly at random (upweighting minority groups as a result), and then samples a fixed number of examples from each of those groups. We initialize group loaders as follows, using `Grouper` that specifies the grouping scheme. ```py From 891798380d96ed0bed7befcaf07e90c21c5b0eb9 Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Sat, 17 Jul 2021 21:34:17 -0700 Subject: [PATCH 241/244] Update README --- README.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index c4108452..40e0437b 100644 --- a/README.md +++ b/README.md @@ -194,7 +194,8 @@ The standard data loader shuffles examples in the training set, and is used for ``` To support other algorithms that rely on specific data loading schemes, we also provide the group data loader. -In each minibatch, the group loader first samples a specified number of groups uniformly at random (upweighting minority groups as a result), and then samples a fixed number of examples from each of those groups. +In each minibatch, the group loader first samples a specified number of groups, and then samples a fixed number of examples from each of those groups. +(By default, the groups are sampled uniformly at random, which upweights minority groups as a result. This can be toggled by the `uniform_over_groups` parameter.) We initialize group loaders as follows, using `Grouper` that specifies the grouping scheme. ```py @@ -205,7 +206,7 @@ We initialize group loaders as follows, using `Grouper` that specifies the group ... batch_size=16) ``` -Lastly, we also provide a data loader for evaluation, which loads examples without shuffling unlike the training loaders. +Lastly, we also provide a data loader for evaluation, which loads examples without shuffling (unlike the training loaders). ```py >>> from wilds.common.data_loaders import get_eval_loader From f5c54d7df634aec26770e7baf797f0079cb8d478 Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Sun, 18 Jul 2021 00:09:36 -0700 Subject: [PATCH 242/244] update --- examples/evaluate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/evaluate.py b/examples/evaluate.py index 85ba26b7..58fe295c 100644 --- a/examples/evaluate.py +++ b/examples/evaluate.py @@ -204,7 +204,7 @@ def evaluate_replicate_for_globalwheat( predicted_labels = torch.load(path_to_predictions) subset: WILDSSubset = dataset.get_subset(split) metadata: torch.Tensor = subset.metadata_array - true_labels = list(itemgetter(*subset.indices)(subset.dataset.y_array)) + true_labels = [subset.dataset.y_array[idx] for idx in subset.indices] return dataset.eval(predicted_labels, true_labels, metadata)[0] From a7a0d1932b2728cb5b4061369437a181604c30ca Mon Sep 17 00:00:00 2001 From: Tony Lee Date: Sun, 18 Jul 2021 00:11:00 -0700 Subject: [PATCH 243/244] remove unused import --- examples/evaluate.py | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/evaluate.py b/examples/evaluate.py index 58fe295c..6931a17d 100644 --- a/examples/evaluate.py +++ b/examples/evaluate.py @@ -3,7 +3,6 @@ import os import urllib.request from ast import literal_eval -from operator import itemgetter from typing import Dict, List from urllib.parse import urlparse From 419baa530e0ff91e6f21901224739faba897f42a Mon Sep 17 00:00:00 2001 From: Pang Wei Koh Date: Sun, 18 Jul 2021 14:55:53 -0700 Subject: [PATCH 244/244] Waterbirds adjusted average accuracy computation --- wilds/datasets/waterbirds_dataset.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/wilds/datasets/waterbirds_dataset.py b/wilds/datasets/waterbirds_dataset.py index 9caeb4cb..79dd5714 100644 --- a/wilds/datasets/waterbirds_dataset.py +++ b/wilds/datasets/waterbirds_dataset.py @@ -122,13 +122,31 @@ def eval(self, y_pred, y_true, metadata, prediction_fn=None): are predicted labels. - y_true (LongTensor): Ground-truth labels - metadata (Tensor): Metadata - - prediction_fn (function): A function that turns y_pred into predicted labels + - prediction_fn (function): A function that turns y_pred into predicted labels Output: - results (dictionary): Dictionary of evaluation metrics - results_str (str): String summarizing the evaluation metrics """ metric = Accuracy(prediction_fn=prediction_fn) - return self.standard_group_eval( + + results, results_str = self.standard_group_eval( metric, self._eval_grouper, y_pred, y_true, metadata) + + # For Waterbirds, the validation and test sets are constructed to be more balanced + # compared to the training set. + # To compute the actual average accuracy over the empirical (training) distribution, + # we therefore weight each groups according to their frequency in the training set. + + results['adj_acc_avg'] = ( + (results['acc_y:landbird_background:land'] * 3498 + + results['acc_y:landbird_background:water'] * 184 + + results['acc_y:waterbird_background:land'] * 56 + + results['acc_y:waterbird_background:water'] * 1057) / + (3498 + 184 + 56 + 1057)) + + del results['acc_avg'] + results_str = f"Adjusted average acc: {results['adj_acc_avg']:.3f}\n" + '\n'.join(results_str.split('\n')[1:]) + + return results, results_str