Skip to content

Commit

Permalink
Refactor atom storage
Browse files Browse the repository at this point in the history
  • Loading branch information
samirelanduk committed Apr 7, 2018
1 parent 484ce5b commit 0adccb0
Show file tree
Hide file tree
Showing 9 changed files with 102 additions and 170 deletions.
7 changes: 3 additions & 4 deletions atomium/structures/chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def verify(sequence):
:returns: ``True`` if the test passes."""

residues = set()
for atom in sequence.atoms():
for atom in sequence._atoms:
residues.add(atom.residue)
if residues:
seq = [list(residues)[0]]
Expand Down Expand Up @@ -102,9 +102,8 @@ class Chain(ResidueSequence, Molecule):
def __init__(self, *atoms, **kwargs):
Molecule.__init__(self, *atoms, **kwargs)
ResidueSequence.verify(self)
for cluster in self._atoms.values():
for atom in cluster:
atom._chain = self
for atom in self._atoms:
atom._chain = self


def __repr__(self):
Expand Down
5 changes: 2 additions & 3 deletions atomium/structures/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,5 @@ class Model(AtomicStructure):

def __init__(self, *atoms):
AtomicStructure.__init__(self, *atoms)
for cluster in self._atoms.values():
for atom in cluster:
atom._model = self
for atom in self._atoms:
atom._model = self
92 changes: 42 additions & 50 deletions atomium/structures/molecules.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,27 +24,27 @@ class AtomicStructure:
:raises TypeError: if non-atoms or AtomicStructures are given."""

def __init__(self, *atoms):
atoms_ = set()
self._atoms = set()
for atom in atoms:
if not isinstance(atom, (Atom, AtomicStructure)):
raise TypeError(
"AtomicStructures need atoms, not '{}'".format(atom)
)
[atoms_.add(atom) if isinstance(atom, Atom)
else atoms_.update(atom.atoms()) for atom in atoms]
self._atoms = {id: set() for id in set([a.id for a in atoms_])}
for atom in atoms_:
self._atoms[atom.id].add(atom)
[self._atoms.add(atom) if isinstance(atom, Atom)
else self._atoms.update(atom._atoms) for atom in atoms]
self._id_atoms = {id: set() for id in set([a.id for a in self._atoms])}
for atom in self._atoms:
self._id_atoms[atom.id].add(atom)


def __repr__(self):
return "<{} ({} atoms)>".format(
self.__class__.__name__, len(self.atoms())
self.__class__.__name__, len(self._atoms)
)


def __contains__(self, member):
return member in self.atoms()
return member in self._atoms


@atom_query
Expand All @@ -63,7 +63,7 @@ def atoms(self):
:param bool metal: If ``False``, metal atoms will be excluded.
:rtype: ``set``"""

return reduce(operator.or_, self._atoms.values())
return set(self._atoms)


def atom(self, *args, **kwargs):
Expand All @@ -85,10 +85,10 @@ def atom(self, *args, **kwargs):
:rtype: ``Atom``"""

if "id" in kwargs:
atoms = self._atoms.get(kwargs["id"])
atoms = self._id_atoms.get(kwargs["id"])
if not atoms: atoms = set()
elif len(args) == 1:
atoms = self._atoms.get(args[0])
atoms = self._id_atoms.get(args[0])
else:
atoms = self.atoms(*args, **kwargs)
if not atoms: atoms = set()
Expand All @@ -103,11 +103,12 @@ def add_atom(self, atom):

if not isinstance(atom, Atom):
raise TypeError("Can only add atoms, not '{}'".format(atom))
if atom.id in self._atoms:
self._atoms[atom.id].add(atom)
if atom.id in self._id_atoms:
self._id_atoms[atom.id].add(atom)
else:
self._atoms[atom.id] = {atom}
self._id_atoms[atom.id] = {atom}
atom.__dict__["_" + self.__class__.__name__.lower()] = self
self._atoms.add(atom)


def remove_atom(self, atom):
Expand All @@ -116,9 +117,10 @@ def remove_atom(self, atom):
:param Atom atom: The atom to remove."""

try:
self._atoms[atom.id].remove(atom)
if not self._atoms[atom.id]: del self._atoms[atom.id]
self._id_atoms[atom.id].remove(atom)
if not self._id_atoms[atom.id]: del self._id_atoms[atom.id]
atom.__dict__["_" + self.__class__.__name__.lower()] = None
self._atoms.remove(atom)
except KeyError: pass


Expand All @@ -144,7 +146,7 @@ def add(self, structure):

if not isinstance(structure, AtomicStructure):
raise TypeError("{} is not an atomic structure".format(structure))
for atom in structure.atoms():
for atom in structure._atoms:
self.add_atom(atom)


Expand All @@ -156,7 +158,7 @@ def remove(self, structure):

if not isinstance(structure, AtomicStructure):
raise TypeError("{} is not an atomic structure".format(structure))
for atom in structure.atoms():
for atom in structure._atoms:
self.remove_atom(atom)


Expand All @@ -169,7 +171,7 @@ def residues(self, id=None, name=None):
:rtype: ``Residue``"""

res = set()
for atom in self.atoms(): res.add(atom.residue)
for atom in self._atoms: res.add(atom.residue)
try: res.remove(None)
except KeyError: pass
if id: res = set(filter(lambda r: r.id == id, res))
Expand Down Expand Up @@ -200,7 +202,7 @@ def molecules(self, id=None, name=None, generic=False, water=True):
:rtype: ``Molecule``"""

molecules = set()
for atom in self.atoms(): molecules.add(atom.molecule)
for atom in self._atoms: molecules.add(atom.molecule)
try: molecules.remove(None)
except KeyError: pass
if id: molecules = set(filter(lambda r: r.id == id, molecules))
Expand Down Expand Up @@ -240,7 +242,7 @@ def chains(self, id=None, name=None):
:rtype: ``Chain``"""

chains = set()
for atom in self.atoms():
for atom in self._atoms:
chains.add(atom.chain)
try:
chains.remove(None)
Expand Down Expand Up @@ -286,7 +288,7 @@ def translate(self, *args, **kwargs):
after translating - the default is 12 decimal places but this can be\
set to ``None`` if no rounding is to be done."""

for atom in self.atoms():
for atom in self._atoms:
atom.translate(*args, **kwargs)


Expand All @@ -306,7 +308,7 @@ def rotate(self, angle, axis, degrees=False, trim=12):
raise ValueError("{} is not a valid axis".format(axis))
angle = math.radians(angle) if degrees else angle
matrix = Atom.generate_rotation_matrix(None, angle, axis)
atoms = list(self.atoms())
atoms = list(self._atoms)
for atom, vector in zip(atoms, [atom.location for atom in atoms]):
atom._x, atom._y, atom._z = matrix.dot(vector)
self.trim(trim)
Expand All @@ -318,7 +320,7 @@ def mass(self):
:rtype: ``float``"""

return round(sum([atom.mass for atom in self.atoms()]), 12)
return round(sum([atom.mass for atom in self._atoms]), 12)


@property
Expand All @@ -328,7 +330,7 @@ def charge(self):
:rtype: ``float``"""

return round(sum([atom.charge for atom in self.atoms()]), 12)
return round(sum([atom.charge for atom in self._atoms]), 12)


@property
Expand All @@ -337,7 +339,7 @@ def formula(self):
:rtype: ``Counter``"""

return Counter([atom.element for atom in self.atoms()])
return Counter([atom.element for atom in self._atoms])


@property
Expand All @@ -348,10 +350,9 @@ def center_of_mass(self):
:returns: (x, y, z) ``tuple``"""

mass = self.mass
atoms = self.atoms()
average_x = sum([atom.x * atom.mass for atom in atoms]) / mass
average_y = sum([atom.y * atom.mass for atom in atoms]) / mass
average_z = sum([atom.z * atom.mass for atom in atoms]) / mass
average_x = sum([atom.x * atom.mass for atom in self._atoms]) / mass
average_y = sum([atom.y * atom.mass for atom in self._atoms]) / mass
average_z = sum([atom.z * atom.mass for atom in self._atoms]) / mass
return (average_x, average_y, average_z)


Expand All @@ -364,7 +365,7 @@ def radius_of_gyration(self):
:rtype: ``float``"""

center_of_mass = self.center_of_mass
atoms = self.atoms()
atoms = self._atoms
square_deviation = sum(
[atom.distance_to(center_of_mass) ** 2 for atom in atoms]
)
Expand All @@ -391,7 +392,7 @@ def pairing_with(self, structure):

if not isinstance(structure, AtomicStructure):
raise TypeError("{} is not an AtomicStructure".format(structure))
atoms, other_atoms = list(self.atoms()), list(structure.atoms())
atoms, other_atoms = list(self._atoms), list(structure._atoms)
if len(atoms) != len(other_atoms):
raise ValueError("{} and {} have different numbers of atoms".format(
self, structure
Expand Down Expand Up @@ -444,7 +445,7 @@ def rmsd_with(self, structure, superimpose=False):

pairing = self.pairing_with(structure)
if superimpose:
atoms = list(self.atoms())
atoms = list(self._atoms)
locations = [atom.location for atom in atoms]
self.superimpose_onto(structure)
sd = sum(a1.distance_to(a2) ** 2 for a1, a2 in pairing.items())
Expand Down Expand Up @@ -479,7 +480,7 @@ def grid(self, size=1, margin=0):
coordinates. The default is 0.
:rtype: ``tuple``"""

atom_locations = [atom.location for atom in self.atoms()]
atom_locations = [atom.location for atom in self._atoms]
dimension_values = []
for dimension in range(3):
coordinates = [loc[dimension] for loc in atom_locations]
Expand Down Expand Up @@ -524,18 +525,11 @@ def atoms_in_sphere(self, x, y, z, radius):
if radius < 0:
raise ValueError("{} is not a valid radius".format(radius))
atoms = filter(
lambda a: a.distance_to((x, y, z)) <= radius, self.atoms()
lambda a: a.distance_to((x, y, z)) <= radius, self._atoms
)
return set(atoms)









def to_file_string(self, file_format, description=None):
"""Converts a structure to a filestring. Currently supported file formats
are: .xyz and .pdb.
Expand Down Expand Up @@ -600,9 +594,8 @@ def __init__(self, *atoms, id=None, name=None):
raise TypeError("Molecule name {} is not a string".format(name))
self._id = id
self._name = name
for cluster in self._atoms.values():
for atom in cluster:
atom._molecule = self
for atom in self._atoms:
atom._molecule = self


def __repr__(self):
Expand Down Expand Up @@ -646,7 +639,7 @@ def model(self):
:rtype: ``Model``"""

for atom in self.atoms():
for atom in self._atoms:
return atom.model


Expand Down Expand Up @@ -699,9 +692,8 @@ class Residue(Molecule):
def __init__(self, *atoms, **kwargs):
Molecule.__init__(self, *atoms, **kwargs)
self._next, self._previous = None, None
for cluster in self._atoms.values():
for atom in cluster:
atom._residue = self
for atom in self._atoms:
atom._residue = self


@property
Expand Down Expand Up @@ -789,7 +781,7 @@ def chain(self):
:rtype: ``Chain``"""

for atom in self.atoms():
for atom in self._atoms:
return atom.chain


Expand Down
Loading

0 comments on commit 0adccb0

Please sign in to comment.