Commit
⚡ cleaning up old code
Robert Turnbull committed Oct 15, 2023
1 parent 78cc482 commit 2b0a828
Showing 1 changed file with 0 additions and 153 deletions.
153 changes: 0 additions & 153 deletions corgi/dataloaders.py
@@ -31,12 +31,6 @@
from seqbank import SeqBank
from .seqtree import SeqTree

@define
class AccessionDetail:
    validation:bool
    node_id:int
    type:int

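For context, @define here is presumably the attrs decorator (imported near the top of the file, outside this hunk), which generates an __init__ from the annotated fields. A minimal illustration with placeholder values:

detail = AccessionDetail(validation=False, node_id=3, type=1)
assert detail.node_id == 3  # fields are set in declaration order or by keyword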

def open_path(path:Path):
    path = Path(path)
@@ -63,16 +57,6 @@ def __call__(self, accession:str):
        return self.seqtree[accession].node_id


class SeqTreeTypeGetter:
    def __init__(self, seqtree:SeqTree):
        self.seqtree = seqtree

    def __call__(self, accession:str):
        return self.seqtree[accession].type.value


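This removed class follows fastai's callable-getter convention: construct it once with the shared state, then call it per item. A sketch assuming a populated SeqTree and a hypothetical accession string:

get_type = SeqTreeTypeGetter(seqtree)
type_value = get_type("NC_012920.1")  # integer value of the stored type enum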


@delegates()
class StratifiedDL(TfmdDL):
    def __init__(self, dataset=None, bs=None, groups=None, **kwargs):
@@ -117,149 +101,12 @@ def show_batch(x: TensorDNA, y, samples, ctxs=None, max_n=20, trunc_at=150, **kw
    return ctxs

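Backing up to the truncated StratifiedDL above: @delegates() (from fastcore) merges TfmdDL's remaining keyword arguments into its signature, so construction presumably looked something like the following, where train_items and class_groups are placeholders:

dl = StratifiedDL(dataset=train_items, bs=32, groups=class_groups)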

def create_datablock_refseq(categories, validation_column="validation", validation_prob=0.2, vocab=None) -> DataBlock:

    # Check if there is a validation column in the dataset; otherwise use a random splitter
    if validation_column:
        splitter = ColSplitter(validation_column)
    else:
        splitter = RandomSplitter(valid_pct=validation_prob, seed=42)

    return DataBlock(
        blocks=(TransformBlock, CategoryBlock(vocab=vocab)),
        splitter=splitter,
        get_y=ColReader("category"),
        item_tfms=RowToTensorDNA(categories),
    )


def create_datablock(seq_length=None, validation_column="validation", validation_prob=0.2, vocab=None) -> DataBlock:

    # Check if we need to slice to a specific sequence length
    if seq_length:
        item_tfms = SliceTransform(seq_length)
    else:
        item_tfms = None

    # Check if there is a validation column in the dataset; otherwise use a random splitter
    if validation_column:
        splitter = ColSplitter(validation_column)
    else:
        splitter = RandomSplitter(valid_pct=validation_prob, seed=42)

    return DataBlock(
        blocks=(TransformBlock, CategoryBlock(vocab=vocab)),
        splitter=splitter,
        get_x=get_sequence_as_tensor,
        get_y=ColReader("category"),
        item_tfms=item_tfms,
    )

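Both removed factories return a fastai DataBlock, so they were presumably driven with a pandas DataFrame carrying 'category' and 'validation' columns. A hedged sketch (the DataFrame and sequence length are placeholders):

datablock = create_datablock(seq_length=1024)
dls = datablock.dataloaders(df, bs=64)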

class DataloaderType(str, Enum):
    PLAIN = "PLAIN"
    WEIGHTED = "WEIGHTED"
    STRATIFIED = "STRATIFIED"

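Because DataloaderType subclasses both str and Enum, its members compare equal to their plain-string values, which makes round-tripping from CLI or config strings trivial:

assert DataloaderType.PLAIN == "PLAIN"
assert DataloaderType("STRATIFIED") is DataloaderType.STRATIFIED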

class AccessionSplitter:
    def __init__(self, accession_details:dict):
        self.accession_details = accession_details

    def __call__(self, objects):
        validation_indexes = mask2idxs(self.accession_details[obj].validation for obj in objects)
        return IndexSplitter(validation_indexes)(objects)


class AccessionGetter:
    def __init__(self, accession_details:dict, attribute:str):
        self.accession_details = accession_details
        self.attribute = attribute

    def __call__(self, accession:str):
        return getattr(self.accession_details[accession], self.attribute)


def create_seqbank_dataloaders(
    csv:Path,
    seqbank:Path,
    batch_size:int=64,
    validation_partition:int=1,
    deform_lambda: float = None,
    validation_seq_length:int=1_000,
    verbose:bool=False,
    label_smoothing:float=0.0,
    gamma:float=0.0,
    **kwargs
):
    seqbank = SeqBank(seqbank)
    csv = Path(csv)
    df = pd.read_csv(csv)

    # Build hierarchy tree
    assert 'hierarchy' in df.columns, f"Cannot find 'hierarchy' column in {csv}."
    classification_tree, classification_to_node, classification_to_node_id = create_hierarchy(
        df['hierarchy'].unique(),
        label_smoothing=label_smoothing,
        gamma=gamma,
    )

    accession_details = {}
    assert 'partition' in df.columns, f"Cannot find 'partition' column in {csv}."
    assert 'type' in df.columns, f"Cannot find 'type' column in {csv}."
    missing = set()
    for _, row in track(df.iterrows(), description="Reading CSV", total=len(df)):
        accession_details[row['accession']] = AccessionDetail(
            validation=(row['partition'] == validation_partition),
            node_id=classification_to_node_id[row['hierarchy']],
            type=row['type'],
        )
        # if row['accession'] not in seqbank:
        #     missing.add(row['accession'])

    if missing:
        with open("MISSING.txt", "w") as f:
            for accession in missing:
                print(accession, file=f)
        raise ValueError(f"{len(missing)} accessions in {csv} are missing from {seqbank}. Written to MISSING.txt")

    del df

    # Set up batch transforms
    before_batch = [
        RandomSliceBatch(only_split_index=0),
        DeterministicSliceBatch(seq_length=validation_seq_length, only_split_index=1),
    ]
    if deform_lambda is not None:
        before_batch.append(DeformBatch(deform_lambda=deform_lambda))

    dataloaders_kwargs = dict(bs=batch_size, drop_last=False, before_batch=before_batch)

    getters = [
        GetTensorDNA(seqbank),
        AccessionGetter(accession_details, 'node_id'),
        AccessionGetter(accession_details, 'type'),
    ]

    blocks = (
        TransformBlock,
        TransformBlock,
        TransformBlock,
        # CategoryBlock(vocab=["nuclear", "mitochondrion", "plastid", "plasmid"], sort=False),
    )
    datablock = DataBlock(
        blocks=blocks,
        splitter=AccessionSplitter(accession_details),
        getters=getters,
        n_inp=1,
    )
    dls = datablock.dataloaders(set(accession_details.keys()), verbose=verbose, **dataloaders_kwargs)

    dls.classification_tree = classification_tree

    return dls

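Finally, a sketch of invoking the removed entry point; per the assertions above, the CSV needs 'accession', 'partition', 'type', and 'hierarchy' columns (the paths here are placeholders):

dls = create_seqbank_dataloaders(
    csv="accessions.csv",
    seqbank="sequences.sb",
    batch_size=32,
    validation_partition=1,
)
classification_tree = dls.classification_tree  # hierarchy attached by the function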

def create_seqtree_dataloaders(
    seqtree:SeqTree,
    seqbank:SeqBank,
