In [1]:
from transduction.smc.smc import SMC, Particle, logsumexp
from transduction.smc.aa import create_dna_translator, get_source_lm_probs, score_sequence
from transduction.lazy_recursive import LazyRecursive
from transduction import Precover

In [2]:
def score(seq, fst, max_length=100000):
    """Return the log of the summed scores of the strings enumerated for `seq`.

    Builds ``Precover(fst, seq)``, takes the first component of its
    ``decomposition``, enumerates that component's ``language`` (bounded by
    ``max_length``), scores every enumerated string with
    ``score_sequence(..., log_space=True)`` and combines the results with
    ``logsumexp``.

    Args:
        seq: Sequence handed to ``Precover`` together with ``fst``.
        fst: Transducer handed to ``Precover``.
        max_length: Bound forwarded to ``language(max_length=...)``.

    Returns:
        The log-sum-exp of the per-string log scores, or ``float('-inf')``
        (the log of an empty sum) when nothing is enumerated.
    """
    # Only the first component of the decomposition is scored; the second is
    # unpacked and discarded.
    # NOTE(review): confirm that ignoring the second component is intended.
    q_component, _ = Precover(fst, seq).decomposition

    # Use a dedicated loop name so the `seq` parameter is not shadowed.
    log_scores = [
        score_sequence(source_seq, log_space=True)
        for source_seq in q_component.language(max_length=max_length)
    ]

    # Nothing enumerated means zero total mass; report it explicitly instead
    # of relying on how logsumexp treats an empty list.
    if not log_scores:
        return float('-inf')
    return logsumexp(log_scores)

def logp_next(seq, fst, vocabulary):
    """Return a normalized log-distribution over the next token after `seq`.

    Each token in ``vocabulary`` is appended to ``seq`` and the extended
    sequence is scored with ``score``. The scores are then normalized by
    subtracting their ``logsumexp``.

    Args:
        seq: Current sequence; must support ``seq + token``.
        fst: Transducer forwarded to ``score``.
        vocabulary: Iterable of candidate next tokens.

    Returns:
        Dict mapping each token to its normalized log-probability. Tokens
        whose score is ``-inf`` stay ``-inf``.
    """
    neg_inf = float('-inf')

    # An empty-string token leaves `seq` unchanged, so its entry is simply
    # score(seq, fst).
    raw_scores = {tok: score(seq + tok, fst) for tok in vocabulary}
    log_norm = logsumexp(list(raw_scores.values()))

    # Keep -inf entries as-is: subtracting would give nan when the normalizer
    # is itself -inf.
    return {
        tok: neg_inf if raw == neg_inf else raw - log_norm
        for tok, raw in raw_scores.items()
    }

In [None]:
# Shared inputs for the cells below: the sequence to score, the transducer,
# and an SMC sampler configured with 10,000 particles.
# NOTE(review): no random seed is set anywhere in this notebook; if SMC is
# stochastic, the sampled outputs below will differ between runs — confirm
# whether SMC exposes a seed.
seq = "QMQMQ"
fst = create_dna_translator()
smc = SMC(fst, LazyRecursive, seq, get_source_lm_probs, num_particles=10000)

In [4]:
# Build the starting particle from `fst.I`, the transducer and the sequence,
# then display its repr.
p = Particle.initial(fst.I, fst, seq)
p

π(x=, #states=1, w=0.00, Active)

In [5]:
# Show the tokens the sampler reports as valid proposals for the starting
# particle `p`.
smc.get_valid_proposal_tokens(p)

['C', 'C']

In [6]:
# Run the sampler by calling the SMC object, then display what it returned.
particles = smc()
particles

[π(x=CAGATGCAA, #states=1, w=-1.43, Univ)]

In [7]:
# Value reported by the sampler after the run above.
# NOTE(review): presumably a log-probability estimate for `seq`, comparable to
# score(seq, fst) — confirm.
smc.get_probs()

-14.026231589279924

In [8]:
# Enumeration-based score of the same `seq` using the `score` helper defined
# in this notebook, for comparison with the sampler's value.
score(seq, fst)

-12.742216077280451

In [9]:
# Per-token dictionary produced by SMC.get_dist with the same inputs as the
# sampler above (10000 passed positionally as the last argument).
# NOTE(review): presumably the sampler's next-token log-distribution — confirm.
dist = SMC.get_dist(fst, LazyRecursive, seq, get_source_lm_probs, 10000)
dist

{'': -0.22949371663893636,
 'Y': -5.057807453941235,
 'R': -5.057807453941239,
 'S': -3.841412129616744,
 'N': -4.6523423458330715,
 'G': -4.6523423458330715,
 'M': -5.057807453941235,
 'H': -4.940024418284855,
 'I': -5.057807453941239,
 'Q': -4.940024418284855,
 'V': -5.175590489597623,
 'F': -4.482443309037679,
 'T': -5.581055597705786,
 'L': -4.770125381489455,
 '*': -4.6523423458330715,
 'P': -4.076978200929512,
 'W': -5.868737670157566,
 'K': -4.246877237724908,
 'C': -5.4632725620494025,
 'E': -4.246877237724908,
 'D': -4.246877237724908,
 'A': -3.959195165273128}

In [10]:
# Enumeration-based next-token log-distribution over the tokens in `fst.B`,
# for comparison with `dist`.
logp_next(seq, fst, fst.B)

{'': -0.6931471805599454,
 'Y': -4.422848629194135,
 'R': -3.4265151896464463,
 'S': -3.1349943408874985,
 'N': -4.3428059215206005,
 'G': -3.7297014486341915,
 'M': -4.853631545286589,
 'H': -5.035953102080548,
 'I': -3.467337184166702,
 'Q': -5.035953102080548,
 'V': -3.7297014486341915,
 'F': -3.5473798918402384,
 'T': -3.649658740960657,
 'L': -2.6422624605642078,
 '*': -3.9528449999484003,
 'P': -2.9565115604007097,
 'W': -5.339139361068291,
 'K': -4.3428059215206005,
 'C': -4.422848629194135,
 'E': -4.017383521085971,
 'D': -4.017383521085971,
 'A': -3.324236340526028}