In [12]:
import torch
import numpy as np

In [64]:
def get_atoms(pdb_fn, chains):
    residues = []
    coords = []
    with open(pdb_fn) as f:
        lines = f.readlines()
        for l in lines:
            if l[:4] == "ATOM":
                if l[21] in chains:
                    coords.append([float(l[30:38]), float(l[38:46]), float(l[46:54])])
                    # RESNAME [17:20]
                    # CHAIN [21]
                    # RESNUM + INSERTION CODE [22:27]
                    residues.append((l[17:20], l[21], l[22:27].strip()))
    return np.array(coords), residues

In [39]:
def atom_pairs(atoms):
    a = torch.Tensor(atoms)
    d = torch.cdist(a, a, p=2.0)
    return np.argwhere(d.numpy() < 3.0)
    


In [43]:
def residue_pairs(residues, atom_pairs):
    pairs = []
    for p in atom_pairs:
        if residues[p[0]] != residues[p[1]]:
            rp = (residues[p[0]], residues[p[1]])
            if rp not in pairs:
                pairs.append(rp)  
    return pairs

In [70]:
atoms, residues = get_atoms("/home/ABTLUS/jose.pereira/Downloads/7a98.pdb", ['A', 'E', 'F'])

In [71]:
p = atom_pairs(atoms)

In [73]:
rp = residue_pairs(residues, p)

In [75]:
rp

[(('GLN', 'A', '14'), ('CYS', 'A', '15')),
 (('CYS', 'A', '15'), ('GLN', 'A', '14')),
 (('CYS', 'A', '15'), ('VAL', 'A', '16')),
 (('CYS', 'A', '15'), ('CYS', 'A', '136')),
 (('VAL', 'A', '16'), ('CYS', 'A', '15')),
 (('VAL', 'A', '16'), ('ASN', 'A', '17')),
 (('ASN', 'A', '17'), ('VAL', 'A', '16')),
 (('ASN', 'A', '17'), ('LEU', 'A', '18')),
 (('ASN', 'A', '17'), ('SER', 'A', '254')),
 (('LEU', 'A', '18'), ('ASN', 'A', '17')),
 (('LEU', 'A', '18'), ('THR', 'A', '19')),
 (('THR', 'A', '19'), ('LEU', 'A', '18')),
 (('THR', 'A', '19'), ('THR', 'A', '20')),
 (('THR', 'A', '20'), ('THR', 'A', '19')),
 (('THR', 'A', '20'), ('ARG', 'A', '21')),
 (('ARG', 'A', '21'), ('THR', 'A', '20')),
 (('ARG', 'A', '21'), ('THR', 'A', '22')),
 (('THR', 'A', '22'), ('ARG', 'A', '21')),
 (('THR', 'A', '22'), ('GLN', 'A', '23')),
 (('GLN', 'A', '23'), ('THR', 'A', '22')),
 (('GLN', 'A', '23'), ('LEU', 'A', '24')),
 (('LEU', 'A', '24'), ('GLN', 'A', '23')),
 (('LEU', 'A', '24'), ('PRO', 'A', '25')),
 (('LEU',

In [74]:
len(rp)

5364

In [None]:
ATOM      1  N   GLN A  14     321.191 317.056 245.806  1.00411.42           N  

In [41]:
p = atom_pairs(a)

In [44]:
residue_pairs(residues, p)

[(0, 1),
 (0, 2),
 (0, 3),
 (0, 4),
 (0, 5),
 (0, 6),
 (0, 7),
 (0, 8),
 (0, 9),
 (1, 0),
 (1, 2),
 (1, 3),
 (1, 4),
 (1, 5),
 (1, 6),
 (1, 7),
 (1, 8),
 (1, 9),
 (2, 0),
 (2, 1),
 (2, 3),
 (2, 4),
 (2, 5),
 (2, 6),
 (2, 7),
 (2, 8),
 (2, 9),
 (3, 0),
 (3, 1),
 (3, 2),
 (3, 4),
 (3, 5),
 (3, 6),
 (3, 7),
 (3, 8),
 (3, 9),
 (4, 0),
 (4, 1),
 (4, 2),
 (4, 3),
 (4, 5),
 (4, 6),
 (4, 7),
 (4, 8),
 (4, 9),
 (5, 0),
 (5, 1),
 (5, 2),
 (5, 3),
 (5, 4),
 (5, 6),
 (5, 7),
 (5, 8),
 (5, 9),
 (6, 0),
 (6, 1),
 (6, 2),
 (6, 3),
 (6, 4),
 (6, 5),
 (6, 7),
 (6, 8),
 (6, 9),
 (7, 0),
 (7, 1),
 (7, 2),
 (7, 3),
 (7, 4),
 (7, 5),
 (7, 6),
 (7, 8),
 (7, 9),
 (8, 0),
 (8, 1),
 (8, 2),
 (8, 3),
 (8, 4),
 (8, 5),
 (8, 6),
 (8, 7),
 (8, 9),
 (9, 0),
 (9, 1),
 (9, 2),
 (9, 3),
 (9, 4),
 (9, 5),
 (9, 6),
 (9, 7),
 (9, 8)]

In [26]:
a = torch.ones((50,3))

In [8]:
a[0,1] = 2

In [28]:
a[:,1:]

tensor([[1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.],
        [1., 1.]])

In [10]:
d = torch.cdist(a, a, p=2.0)

In [11]:
d > 0.5

tensor([[False,  True,  True,  ...,  True,  True,  True],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        ...,
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False],
        [ True, False, False,  ..., False, False, False]])

In [14]:
d.numpy()

array([[0., 1., 1., ..., 1., 1., 1.],
       [1., 0., 0., ..., 0., 0., 0.],
       [1., 0., 0., ..., 0., 0., 0.],
       ...,
       [1., 0., 0., ..., 0., 0., 0.],
       [1., 0., 0., ..., 0., 0., 0.],
       [1., 0., 0., ..., 0., 0., 0.]], dtype=float32)

In [23]:
pairs = np.argwhere(d.numpy())

In [22]:
residues = np.array([[i]*5 for i in range(10)]).flatten()

In [24]:
residue_pairs = []
for p in pairs:
    if residues[p[0]] != residues[p[1]]:
        rp = (residues[p[0]], residues[p[1]])
        if rp not in residue_pairs:
            residue_pairs.append(rp)

In [25]:
residue_pairs

[(0, 1),
 (0, 2),
 (0, 3),
 (0, 4),
 (0, 5),
 (0, 6),
 (0, 7),
 (0, 8),
 (0, 9),
 (1, 0),
 (2, 0),
 (3, 0),
 (4, 0),
 (5, 0),
 (6, 0),
 (7, 0),
 (8, 0),
 (9, 0)]