In [31]:
import torch
from gpn.phylogpn import model, tokenizer # Wrapper around HuggingFace

# Example data
seqs = [
    "TATAAA",
    "GGCCAATCT",
    "CACGTG",
    "AGGTCACGT",
    "GCCAGCC",
    "GGGGATTTCC"
]

# Output length is input length minus 480 (the receptive field size minus 1)
pad_token = tokenizer.pad_token
pad_size = 481 // 2
pad_sequence = lambda seq: pad_token * pad_size + seq + pad_token * pad_size
padded_seqs = [pad_sequence(seq) for seq in seqs]
input_tensor = tokenizer(padded_seqs, return_tensors="pt", padding=True)["input_ids"]

with torch.no_grad():
    padded_embeddings = model.get_embeddings(input_tensor)
    padded_logits = model(input_tensor) # These are log rate parameters for the F81 model

embeddings = []
logits = []

for i in range(len(seqs)):
    length = len(seqs[i])
    embeddings.append(padded_embeddings[i, :length])
    logits.append({})

    for k in "ACGT":
        logits[-1][k] = padded_logits[k][i, :length]

In [32]:
# `embeddings` is a list of tensors, one per item in the batch, each containing embeddings for each position in the sequence
embeddings[0]

tensor([[ 0.2399,  0.2717,  0.1193,  ..., -0.2694,  1.2095, -0.8143],
        [-0.3725,  1.2785, -0.0264,  ..., -0.5638, -0.1288, -1.1086],
        [-0.7089,  0.8101, -1.5966,  ..., -0.4742,  1.5401, -0.8086],
        [-0.6684,  1.1999, -1.6881,  ...,  0.0269, -0.1773,  0.2568],
        [-0.1105,  0.2217, -0.3428,  ..., -0.4531,  1.1346, -0.8851],
        [-0.8363,  1.2306,  0.3084,  ..., -0.7149,  0.5221,  0.4968]])

In [34]:
# `logits` is a list of dictionaries, one per item in the batch, each containing the log rate parameters for the F81 model
logits[0]

{'A': tensor([ 0.3503,  0.1109,  0.5503, -0.1649,  0.6182,  0.1694]),
 'C': tensor([-0.0261,  0.3540, -0.3431,  0.5716,  0.4910,  0.3965]),
 'G': tensor([0.4168, 0.4671, 0.3895, 0.1255, 0.1161, 0.3556]),
 'T': tensor([-0.2667,  0.0614, -0.2464,  0.5956, -0.1491, -0.1738])}

In [40]:
# Get likelihoods

likelihood_list = []

for logit_dict in logits:
    logit_tensor = torch.stack([logit_dict[k] for k in "ACGT"])
    likelihood_tensor = torch.softmax(logit_tensor, dim=1)
    likelihood_dict = {k: likelihood_tensor[i] for i, k in enumerate("ACGT")}
    likelihood_list.append(likelihood_dict)

likelihood_list[0]

{'A': tensor([0.1740, 0.1369, 0.2125, 0.1039, 0.2274, 0.1452]),
 'C': tensor([0.1218, 0.1781, 0.0887, 0.2214, 0.2042, 0.1858]),
 'G': tensor([0.1834, 0.1929, 0.1784, 0.1370, 0.1358, 0.1725]),
 'T': tensor([0.1250, 0.1736, 0.1276, 0.2961, 0.1406, 0.1372])}

In [42]:
# Log likelihood ratios are used to score whether a substitution is more likely under one model than another
# For example, the log likelihood ratio of a C to T substitution at position 2 in the first sequence is:
logits[0]["T"][1] - logits[0]["C"][1]

tensor(-0.2926)