In [10]:
import torch
import einx
from einops import pack

In [85]:
batch_size = 1
seq_len = 10
molecule_feats = torch.stack([
    torch.arange(seq_len).unsqueeze(0).repeat(batch_size, 1), # token_index
], dim=-1).reshape((batch_size, seq_len, 1))

def onehot(x, bins):
    dist_from_bins = einx.subtract('... i, j -> ... i j', x, bins)
    indices = dist_from_bins.abs().min(dim = -1, keepdim=True).indices
    one_hots = torch.nn.functional.one_hot(indices.long(), num_classes = len(bins))
    return one_hots

In [87]:
token_idx

tensor([[0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]])

In [103]:
r_max = 32
s_max = 2

token_idx = torch.arange(seq_len).unsqueeze(0).expand(batch_size, seq_len)

# Compute residue, token, and chain distances
diff_token_idx = einx.subtract('b i, b j -> b i j', token_idx, token_idx)
# Mask for same residue, chain, and entity
mask_same_entity = torch.ones((batch_size, seq_len, seq_len, 1))
# Compute clipped distances
d_token = torch.clip(diff_token_idx + r_max, 0, 2 * r_max)
d_res = torch.full((batch_size, seq_len, seq_len), r_max)
d_chain = torch.full((batch_size, seq_len, seq_len), 2*s_max+1)
# Define bins
r_arange = torch.arange(2*r_max + 2)
s_arange = torch.arange(2*s_max + 2)
# Assign 1-hot encoding of distances
a_rel_pos = onehot(d_res, r_arange)
a_rel_token = onehot(d_token, r_arange)
a_rel_chain = onehot(d_chain, s_arange)
# Concatenate tensors and project
out, _ = pack((a_rel_pos, a_rel_token, mask_same_entity, a_rel_chain), 'b i j *')