In [16]:
from transduction.smc.smc import SMC, Particle, logsumexp
from transduction.smc.aa import create_dna_translator, get_source_lm_probs, score_sequence
from transduction.lazy_recursive import LazyRecursive
from transduction import Precover

In [17]:
def score(seq, fst, max_length=100000):
    """Log-sum of source-LM scores over the quotient of `seq`'s precover.

    Builds ``Precover(fst, seq)``, takes the quotient half of its
    decomposition, and log-sum-exps ``score_sequence(x, log_space=True)`` over
    every source string ``x`` the quotient language yields. In this notebook
    it serves as the enumerative reference that ``smc.get_probs()`` is
    compared against.

    Args:
        seq: Output-side string to score.
        fst: Transducer passed to ``Precover``.
        max_length: Forwarded to ``quotient.language(max_length=...)`` to bound
            the enumeration.

    Returns:
        The log-sum-exp of the per-string log scores, or ``float('-inf')`` when
        the quotient language yields no strings.
    """
    precover = Precover(fst, seq)
    # NOTE(review): the remainder half of the decomposition is not used; this
    # is only the full mass if the remainder carries none (or if
    # score_sequence already accounts for it) -- confirm.
    quotient, _remainder = precover.decomposition
    # `source` (not `seq`) so the loop does not shadow the parameter.
    log_contributions = [
        score_sequence(source, log_space=True)
        for source in quotient.language(max_length=max_length)
    ]
    if not log_contributions:
        # Zero mass; don't depend on how logsumexp treats an empty list.
        return float('-inf')
    return logsumexp(log_contributions)

def logp_next(seq, fst, vocabulary, max_length=100000):
    """Normalized log-distribution over the next output symbol after `seq`.

    Each candidate ``token`` is scored as ``score(seq + token, fst)`` and the
    scores are normalized over `vocabulary` with log-sum-exp.

    Args:
        seq: Output string observed so far.
        fst: Transducer forwarded to ``score``.
        vocabulary: Iterable of candidate output symbols (strings).
        max_length: Forwarded to ``score`` to bound its enumeration; the
            default matches ``score``'s own default.

    Returns:
        Dict mapping each token to its normalized log-probability. Tokens with
        zero mass map to ``float('-inf')``; an empty vocabulary yields ``{}``.

    NOTE(review): if `vocabulary` contains the empty string (as ``fst.B`` does
    in this notebook), ``seq + ''`` is just `seq`, so that entry receives the
    whole prefix score of `seq` rather than an end-of-sequence probability
    (in this notebook it normalizes to log 0.5) -- confirm this is intended.
    """
    unnormalized = {
        token: score(seq + token, fst, max_length=max_length)
        for token in vocabulary
    }
    if not unnormalized:
        return {}
    log_Z = logsumexp(list(unnormalized.values()))
    neg_inf = float('-inf')
    # Keep zero-mass tokens at -inf explicitly: when every token has zero
    # mass, log_Z is -inf and (-inf) - (-inf) would be NaN.
    return {
        token: neg_inf if log_prob == neg_inf else log_prob - log_Z
        for token, log_prob in unnormalized.items()
    }

In [18]:
# Target output string. The particles printed below are DNA strings that
# translate codon-by-codon to it (e.g. CAG/CAA -> Q, ATG -> M).
seq = "QMQMQ"
fst = create_dna_translator()
# NOTE(review): if SMC samples/resamples randomly, later outputs can vary
# between runs; no random seed is set in the cells shown here.
smc = SMC(fst, LazyRecursive, seq, get_source_lm_probs, num_particles=100)

In [19]:
# Build one initial particle by hand (from `fst.I`, presumably the FST's
# initial states) to inspect the empty-prefix state.
p = Particle.initial(fst.I, fst, seq)
p

π(x=, #states=1, w=0.00, Active)

In [20]:
# Source symbols the SMC considers valid proposals from the initial particle.
# NOTE(review): the output lists 'C' twice -- possibly one entry per FST arc
# rather than per distinct symbol; confirm whether duplicates are intended.
smc.get_valid_proposal_tokens(p)

['C', 'C']

In [21]:
# Run the sampler. Each printed particle shows its source string `x` and
# weight `w` (presumably a log-weight). The saved output below is truncated:
# it shows 20 of the 100 particles.
particles = smc()
particles

[π(x=CAGATGCAGATGCAG, #states=1, w=-20.82, Univ),
 π(x=CAGATGCAGATGCAG, #states=1, w=-20.82, Univ),
 π(x=CAAATGCAGATGCAA, #states=1, w=-21.22, Univ),
 π(x=CAGATGCAAATGCAG, #states=1, w=-21.22, Univ),
 π(x=CAGATGCAAATGCAA, #states=1, w=-21.22, Univ),
 π(x=CAGATGCAGATGCAA, #states=1, w=-20.82, Univ),
 π(x=CAGATGCAGATGCAG, #states=1, w=-20.82, Univ),
 π(x=CAAATGCAGATGCAG, #states=1, w=-21.22, Univ),
 π(x=CAGATGCAAATGCAA, #states=1, w=-21.22, Univ),
 π(x=CAAATGCAAATGCAG, #states=1, w=-21.63, Univ),
 π(x=CAGATGCAGATGCAG, #states=1, w=-20.82, Univ),
 π(x=CAAATGCAGATGCAG, #states=1, w=-21.22, Univ),
 π(x=CAGATGCAAATGCAG, #states=1, w=-21.22, Univ),
 π(x=CAGATGCAGATGCAA, #states=1, w=-20.82, Univ),
 π(x=CAAATGCAAATGCAG, #states=1, w=-21.63, Univ),
 π(x=CAAATGCAAATGCAG, #states=1, w=-21.63, Univ),
 π(x=CAAATGCAAATGCAG, #states=1, w=-21.63, Univ),
 π(x=CAAATGCAAATGCAG, #states=1, w=-21.63, Univ),
 π(x=CAGATGCAAATGCAG, #states=1, w=-21.22, Univ),
 π(x=CAGATGCAGATGCAA, #states=1, w=-20.82, Univ),


In [22]:
# Presumably the SMC estimate of the log-probability of `seq`; compare with
# the enumerative `score(seq, fst)` in the next cell.
smc.get_probs()

-21.10240551988702

In [23]:
# Enumerative reference: log-sum-exp of source scores over the precover quotient.
score(seq, fst)

-21.102405519887025

In [24]:
# SMC estimate of the next-symbol log-distribution after `seq` (the trailing
# 100 is presumably the particle count). Compare with `logp_next` below.
dist = SMC.get_dist(fst, LazyRecursive, seq, get_source_lm_probs, 100)
dist

{'': -0.6948858197656271,
 'R': -3.4421021900541255,
 'M': -4.836321989521576,
 'I': -3.473748729071776,
 'P': -2.954759794666625,
 'V': -3.7589618248941257,
 'H': -5.035945015951281,
 'T': -3.621544417016654,
 'D': -4.032468402244543,
 'E': -3.9795130221966026,
 'L': -2.706159072257311,
 'N': -4.338733969130811,
 'Y': -4.388016722736918,
 'F': -3.4968471797584577,
 'C': -4.3515621333849275,
 'S': -3.138825031065398,
 'A': -3.292452287693063,
 'W': -5.3140934654293055,
 'Q': -5.052367672435555,
 'K': -4.363368645971914,
 '*': -3.9378370235096973,
 'G': -3.67379077318709}

In [25]:
# Enumerative next-symbol log-distribution over `fst.B` (presumably the output
# alphabet). `fst.B` contains '' here, and seq + '' == seq, so that entry gets
# the prefix score of `seq` itself (log 0.5 after normalization below).
logp_next(seq, fst, fst.B)

{'': -0.6931471805599472,
 'R': -3.426515189646448,
 'M': -4.853631545286593,
 'I': -3.467337184166702,
 'P': -2.9565115604007133,
 'V': -3.7297014486341915,
 'H': -5.035953102080548,
 'T': -3.649658740960657,
 'D': -4.017383521085971,
 'E': -4.017383521085971,
 'L': -2.6422624605642078,
 'N': -4.3428059215206005,
 'Y': -4.422848629194135,
 'F': -3.54737989184024,
 'C': -4.422848629194135,
 'S': -3.1349943408874985,
 'A': -3.324236340526028,
 'W': -5.339139361068291,
 'Q': -5.035953102080548,
 'K': -4.3428059215206005,
 '*': -3.9528449999484003,
 'G': -3.7297014486341915}