Skip to content

Commit

Permalink
Merge pull request #11 from til-unc/fast-case-standard-format
Browse files Browse the repository at this point in the history
Fast case standard format
  • Loading branch information
iskandr committed Dec 3, 2020
2 parents 6273858 + 84eada4 commit 9de90a1
Show file tree
Hide file tree
Showing 7 changed files with 190 additions and 25 deletions.
23 changes: 12 additions & 11 deletions mhcgnomes/allele.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple, Union
from typing import List, Tuple, Union, Iterable


from .gene import Gene
Expand All @@ -26,9 +26,9 @@ class Allele(ResultWithGene):
def __init__(
self,
gene : Gene,
allele_fields : Union[List[str], Tuple[str]],
annotations : Union[List[str], Tuple[str]] = [],
mutations : Union[List[Mutation], Tuple[Mutation]] = [],
allele_fields : Iterable[str],
annotations : Iterable[str] = (),
mutations : Iterable[Mutation] = (),
raw_string=None):
ResultWithGene.__init__(
self,
Expand Down Expand Up @@ -118,23 +118,24 @@ def split_allele_fields(cls, allele_fields):
@classmethod
def get_with_gene(
cls,
gene,
allele_fields,
annotations=[],
mutations=[],
raw_string=None):
gene : Gene,
allele_fields: Iterable[str],
annotations: Union[Iterable[str], None] = None,
mutations: Union[Iterable[Mutation], None] = None,
raw_string: Union[str, None] = None):
if gene is None:
return None

allele_fields = cls.split_allele_fields(allele_fields)

if len(allele_fields) == 0:
return None

if annotations is None:
annotations = []
annotations = ()

if mutations is None:
mutations = []
mutations = ()

return Allele(
gene=gene,
Expand Down
28 changes: 17 additions & 11 deletions mhcgnomes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,29 +29,38 @@ def unique(xs : Iterable):
unique_set.add(xi)
return result

def arg_to_cache_key(x, _primitive_types={bool, int, str, float}):

def arg_to_cache_key(x):
if x is None:
return None

t = type(x)
if t is int or t is str or t is bool or t is float:

if t is int or t is str or t is bool or t is float or t is type:
return x

if t is list or t is tuple:
if len(x) == 0:
value = ()
return ()
elif len(x) == 1:
value = (arg_to_cache_key(x[0]),)
return (arg_to_cache_key(x[0]),)
else:
value = tuple([arg_to_cache_key(xi) for xi in x])
elif t is dict:
value = tuple([
(arg_to_cache_key(k), arg_to_cache_key(v))
for (k, v) in x.items()])
if len(x) == 0:
value = ()
elif len(x) == 1:
((k, v),) = x.items()
value = (arg_to_cache_key(k), arg_to_cache_key(v))
else:
value = tuple([
(arg_to_cache_key(k), arg_to_cache_key(v))
for (k, v) in x.items()])
else:
value = x
return (t.__name__, value)


def cache(fn):
"""
Memoization function which tries to freeze non-hashable types like
Expand All @@ -60,10 +69,7 @@ def cache(fn):
"""
cache = {}
def cached_fn(*args, **kwargs):
if not args:
args_key = ()
else:
args_key = arg_to_cache_key(args)
args_key = arg_to_cache_key(args)
if not kwargs:
kwargs_key = ()
else:
Expand Down
13 changes: 13 additions & 0 deletions mhcgnomes/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from .result_sorting import pick_best_result
from .serotype import Serotype
from .species import Species, infer_species_from_prefix
from .standard_format import parse_standard_allele_format
from .token import Token
from .tokenize import tokenize

Expand Down Expand Up @@ -1045,6 +1046,18 @@ def parse_single_token_to_multiple_candidates(
seq = token.seq
raw_string = token.raw_string

standard_result = parse_standard_allele_format(
seq,
raw_string=raw_string,
default_species=default_species)

if standard_result:

if self.verbose:
print(f"""=== Standard format result """)
print(standard_result)
return [standard_result]

# list containing all candidate results
parse_candidates = []

Expand Down
3 changes: 1 addition & 2 deletions mhcgnomes/species.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def get_multiple(cls, species_name):
def get(cls, species_name):
if type(species_name) is Species:
return species_name
elif type(species_name) is not str:
elif species_name is None or type(species_name) is not str:
return None

species_objects = cls.get_multiple(species_name)
Expand Down Expand Up @@ -316,7 +316,6 @@ def get_mhc_class_of_gene(self, gene_name):
gene_name = self.normalize_gene_name_if_exists(gene_name)
return self.gene_name_to_mhc_class.get(gene_name)

@cache
def get_known_allele(self, gene_name, allele_name):
gene_name_candidates = {gene_name}

Expand Down
117 changes: 117 additions & 0 deletions mhcgnomes/standard_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import Union

from .species import Species
from .gene import Gene
from .allele import Allele

# Pattern for fully-specified allele names such as "HLA-A*02:01:01:02L".
# Every piece is a raw string so that regex escapes like `\d` and `\*` are
# passed to `re` verbatim instead of being treated as (invalid) Python string
# escape sequences, which emit warnings on modern interpreters.
_standard_allele_regex_str = (
    r"([a-zA-Z]+-)?"  # optional species prefix, including the trailing '-'
    r"([a-zA-Z]+\d?|\d{1,2})\*"  # gene, either e.g. "A1" or "88", then '*'
    r"(\d{2,3})"  # mandatory first allele field with 2 or 3 digits
    r"(:\d{2,3})?"  # optional second allele field with 2 or 3 digits
    r"(:\d\d)?"  # optional third allele field, exactly 2 digits
    r"(:\d\d)?"  # optional fourth allele field, exactly 2 digits
    r"([a-zA-Z])?"  # optional single-letter annotation at the end
)
_standard_allele_regex = re.compile(_standard_allele_regex_str)


def parse_standard_allele_format(
        seq: str,
        raw_string: Union[str, None] = None,
        default_species: Union[str, Species, None] = None):
    """
    Fast path for allele names that are already in the standard layout::

        Species-Gene*001:01:01:01

    The whole of `seq` must match the module-level standard allele regex.
    One to four allele fields are accepted: the first two may have 2 or 3
    digits, the last two exactly 2 digits. A single trailing letter is
    treated as an annotation and is upper-cased.

    Parameters
    ----------
    seq : str
        Sequence to parse

    raw_string : str
        Raw string the sequence was derived from; falls back to `seq`
        when not given

    default_species : str, Species, or None
        Species to assume when `seq` carries no species prefix

    Returns
    -------
    Allele or None
        None is returned whenever the sequence does not match, or the
        species or gene cannot be resolved.
    """
    matched = _standard_allele_regex.fullmatch(seq)
    if matched is None:
        return None

    captured = matched.groups()
    if len(captured) < 3:
        return None

    species_prefix = captured[0]
    gene_name = captured[1]
    if gene_name is None:
        return None

    # Resolve the species either from the explicit "Prefix-" or the default.
    if species_prefix is None:
        species = Species.get(default_species)
    elif len(species_prefix) >= 2 and species_prefix[-1] == "-":
        species = Species.get(species_prefix[:-1])
    else:
        return None
    if species is None:
        return None

    gene = Gene.get(species, gene_name)
    if gene is None:
        return None

    # Collect consecutive allele fields, stopping at the first absent one.
    # Every field after the first was captured with its leading ':'.
    allele_fields = []
    for position, captured_field in enumerate(captured[2:-1]):
        if captured_field is None:
            break
        if position > 0:
            captured_field = captured_field[1:]
        allele_fields.append(captured_field)
    if len(allele_fields) == 0:
        return None

    trailing_letter = captured[-1]
    annotations = [trailing_letter.upper()] if trailing_letter else []

    if raw_string is None:
        raw_string = seq

    return Allele.get_with_gene(
        gene,
        allele_fields,
        annotations=annotations,
        raw_string=raw_string)
28 changes: 28 additions & 0 deletions test/test_standard_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from mhcgnomes.standard_format import parse_standard_allele_format
from mhcgnomes import Allele
from nose.tools import eq_

def test_parse_standard_allele_format_HLA_A_02_01():
    """Parsing "HLA-A*02:01" equals Allele.get with two allele fields."""
    parsed = parse_standard_allele_format(seq="HLA-A*02:01")
    expected = Allele.get("HLA", "A", "02", "01")
    eq_(parsed, expected)

def test_parse_standard_allele_format_HLA_A_02_01_01():
    """Parsing "HLA-A*02:01:01" equals Allele.get with three allele fields."""
    parsed = parse_standard_allele_format(seq="HLA-A*02:01:01")
    expected = Allele.get("HLA", "A", "02", "01", "01")
    eq_(parsed, expected)

def test_parse_standard_allele_format_HLA_A_02_01_01_01():
    """Parsing "HLA-A*02:01:01:01" equals Allele.get with four allele fields."""
    parsed = parse_standard_allele_format(seq="HLA-A*02:01:01:01")
    expected = Allele.get("HLA", "A", "02", "01", "01", "01")
    eq_(parsed, expected)


def test_parse_standard_allele_format_HLA_A_02_01_01_02L():
    """Parsing "HLA-A*02:01:01:02L" equals the four-field allele annotated "L"."""
    parsed = parse_standard_allele_format(seq="HLA-A*02:01:01:02L")
    expected = Allele.get(
        "HLA", "A", "02", "01", "01", "02", annotation="L")
    eq_(parsed, expected)

def test_parse_standard_allele_format_DLA_88_021_01():
    """Parsing "DLA-88*01:01" (an all-digit gene name) equals Allele.get."""
    # NOTE(review): the function name says "021" but the parsed string uses
    # "01" as its first field -- confirm which one was intended.
    parsed = parse_standard_allele_format(seq="DLA-88*01:01")
    expected = Allele.get("DLA", "88", "01", "01")
    eq_(parsed, expected)

def test_parse_standard_allele_format_A_02_01():
    """A prefix-less "A*02:01" with default_species="HLA" equals the HLA allele."""
    parsed = parse_standard_allele_format(
        seq="A*02:01",
        default_species="HLA")
    expected = Allele.get("HLA", "A", "02", "01")
    eq_(parsed, expected)
3 changes: 2 additions & 1 deletion test/timing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import numpy as np

import pandas as pd
from time import time
from mhcgnomes import parse

def run(n_repeats=3, filename="MHC_prot.fasta"):
if filename.endswith("csv"):
df = pd.read_csv(filename)
Expand Down

0 comments on commit 9de90a1

Please sign in to comment.