In [1]:
# changing in order to not reset paths
%cd -q ../../project/parallel_synthesis/experiments

from copy import copy
from protera_stability.config.lazy import LazyCall as L
from protera_stability.config.instantiate import instantiate
from protera_stability.config.common.data import (
    base_dataset,
    base_dataloader,
    base_sampler,
    get_train_val_indices,
)

# Does our LazyConfig work?

In [2]:
dataset = instantiate(base_dataset.data)

In [3]:
train_idx, valid_idx = get_train_val_indices(dataset, 0.8)

assert len(set(train_idx).intersection(set(valid_idx))) == 0

In [4]:
base_sampler.name  # this should change for each experiment!

''

## *No Sampler* 
> this is because we aren't using a "special" sampling method, therefore we directly pass the indices

In [5]:
a_sampler = copy(base_sampler)
a_sampler.random.indices.set_indices = train_idx
a_sampler.random.indices.random_percent = 1.0

all_data_sampler = instantiate(a_sampler.random)

len(all_data_sampler), len(train_idx), next(iter(all_data_sampler))

(6564, 6564, 369)

## Instantiate Diversity Sampler

In [6]:
div_sampler = copy(base_sampler)
div_sampler.diversity.set_sequences.dataset = base_dataset.data
div_sampler.diversity.set_sequences.set_indices = train_idx
div_sampler.diversity.max_size = int(len(dataset) * 0.8)

diversity_sampler = instantiate(div_sampler.diversity)

assert len(set(diversity_sampler.indices).intersection(set(valid_idx))) == 0

len(diversity_sampler), diversity_sampler.stopped_by

  rank_zero_warn(f"No correct seed found, seed set to {seed}")
Global seed set to 215523697


(6141, 'CUTOFF')

## Instantiate Random Sampler

In [7]:
rand_sampler = copy(base_sampler)
rand_sampler.random.indices.dataset = base_dataset
rand_sampler.random.indices.set_indices = train_idx

random_sampler = instantiate(rand_sampler.random)

assert len(set(random_sampler.indices).intersection(set(valid_idx))) == 0

len(random_sampler), len(dataset) * 0.3, len(train_idx)

(0, 2461.2, 6564)

### Validation Sampler

In [8]:
val_sampler = copy(base_sampler)
val_sampler.random.indices.set_indices = valid_idx
val_sampler.random.indices.random_percent = 1.0
valid_sampler = instantiate(val_sampler.random)

len(valid_sampler), len(valid_idx)

(1640, 1640)

## Instantiate Dataloaders

In [15]:
dl = copy(base_dataloader)

In [21]:
# dl.train.sampler = a_sampler.random
dl.train.sampler = div_sampler.diversity
instantiate(dl.train)

Global seed set to 215523697


<torch.utils.data.dataloader.DataLoader at 0x7f1ca495e410>

In [22]:
dl.valid.sampler = val_sampler.random
instantiate(dl.valid)

<torch.utils.data.dataloader.DataLoader at 0x7f1c9bc60050>

In [23]:
instantiate(dl.test)

<torch.utils.data.dataloader.DataLoader at 0x7f1cac5a9e10>

## Check if we have all our data

In [24]:
from protera_stability.data.dataset import ProteinStabilityDataset

In [25]:
count_samples = 0
for X, y in instantiate(dl.train):
    count_samples += X.shape[0]

int(
    len(
        ProteinStabilityDataset(
            proteins_path="../data/stability_train.h5",
            ret_dict=False,
        )
    )
    * 0.8
), count_samples  # difference might be due to diversity cutoff

Global seed set to 215523697


(6563, 6141)

In [14]:
count_samples = 0
for X, y in instantiate(dl.valid):
    count_samples += X.shape[0]

count_samples == len(valid_idx)

True

In [20]:
count_samples = 0
for X, y in instantiate(dl.test):
    count_samples += X.shape[0]

count_samples == len(
    ProteinStabilityDataset(
        proteins_path="../data/stability_test.h5",
        ret_dict=False,
    )
)

True