In [None]:
%matplotlib inline
import itertools
import random

import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Trial-type pools for the two blocks: each type appears four times.
trial_types_12 = [1, 1, 1, 1, 2, 2, 2, 2]
trial_types_34 = [3, 3, 3, 3, 4, 4, 4, 4]
# Combined session = both blocks. `+` builds a new list, so shuffling it never aliases the block lists.
trial_types = trial_types_12 + trial_types_34
# Alternative, longer ITI schedule: four lengths repeated 8 times (32 ITIs).
# ITI labels carry a trailing " (presumably seconds -- confirm).
# Kept under its own name so it is not silently shadowed: the `iti_lengths`
# passed to find_iti below is set in the following cell.
iti_lengths_long = ['150"', '270"', '210"', '330"'] * 8

In [None]:
# ITI schedule used by find_iti: four lengths cycled 16 times (64 ITIs in total).
iti_lengths = ['90"', '105"', '135"', '150"'] * 16

In [None]:
def check_sequence(sequence, n_row):
    """Return True if no value in ``sequence`` repeats more than ``n_row`` times in a row.

    Every maximal run of consecutive equal elements is checked, including a run
    that reaches the very end of the sequence.

    Parameters
    ----------
    sequence : sequence
        Trial types, outcome codes, or ITI labels; elements are compared with ``==``.
    n_row : int
        Maximum allowed run length; must be at least 2.

    Returns
    -------
    bool
        False if any run is longer than ``n_row``, otherwise True
        (an empty sequence is accepted).

    Raises
    ------
    ValueError
        If ``n_row`` is less than 2.
    """
    if n_row <= 1:
        raise ValueError("n_row must be >= 2")
    # groupby yields one group per run of equal consecutive elements, so a
    # trailing run is measured just like any other.
    return all(sum(1 for _ in run) <= n_row for _, run in itertools.groupby(sequence))

In [None]:
def check_trial_prior(sequence, trial, n_prior):
    """Return True if no value immediately precedes ``trial`` more than ``n_prior`` times.

    For every occurrence of ``trial`` that has a predecessor, the element just
    before it is tallied; the check fails if any tally exceeds ``n_prior``.
    The first element of ``sequence`` has no predecessor and is not counted.

    Parameters
    ----------
    sequence : sequence
        Ordered trial codes.
    trial : object
        The trial type whose predecessors are counted.
    n_prior : int
        Maximum number of times any single value may directly precede ``trial``.

    Returns
    -------
    bool
        True if every predecessor tally is at most ``n_prior``.
    """
    # Start every observed value at zero so the tally covers all trial types.
    counts = {value: 0 for value in np.unique(sequence)}
    # Walk adjacent (previous, current) pairs. Indexing with i-1 at i == 0 would
    # wrap to sequence[-1] and count the last trial as the first one's predecessor.
    for previous, current in zip(sequence, sequence[1:]):
        if current == trial:
            counts[previous] += 1
    return all(count <= n_prior for count in counts.values())

def check_prior(sequence, n_prior):
    """Return True only if ``check_trial_prior`` passes for every distinct value in ``sequence``.

    Distinct values are visited in ``np.unique`` (sorted) order, and checking
    stops at the first value that fails.
    """
    for trial_type in np.unique(sequence):
        if not check_trial_prior(sequence, trial_type, n_prior):
            return False
    return True

In [None]:
def find_sequence(trial_types, n_row, n_prior, n_sequences, rng=None, max_attempts=None):
    """Shuffle ``trial_types`` repeatedly and collect ``n_sequences`` orderings that pass all checks.

    A shuffled ordering is kept only if
      * no trial type appears more than ``n_row`` times in a row,
      * no outcome code (``trial_type % 2``) appears more than ``n_row`` times in a row, and
      * ``check_prior(sequence, n_prior)`` accepts it.

    The caller's ``trial_types`` is left untouched: shuffling happens on a private copy.

    Parameters
    ----------
    trial_types : list or numpy.ndarray
        Trial codes to order; must support ``.copy()``.
    n_row : int
        Maximum run length for trial types and for outcome codes.
    n_prior : int
        Predecessor limit passed to ``check_prior``.
    n_sequences : int
        Number of accepted orderings to return.
    rng : optional
        Object with a ``shuffle`` method (e.g. ``np.random.default_rng(seed)``) for
        reproducible output. Defaults to the global ``np.random`` state.
    max_attempts : int, optional
        Maximum number of shuffles before giving up. ``None`` (default) retries
        indefinitely, which never terminates if the constraints are unsatisfiable.

    Returns
    -------
    list
        ``n_sequences`` independent copies of accepted orderings (duplicates are possible).

    Raises
    ------
    RuntimeError
        If ``max_attempts`` shuffles are used up before enough orderings are found.
    """
    shuffler = np.random if rng is None else rng
    # Shuffle a private copy in place each round. The caller's list is never mutated,
    # yet with a seeded global RNG successive shuffles chain exactly as they would
    # if the input itself were shuffled.
    pool = trial_types.copy()
    sequences = []
    attempts = 0

    while len(sequences) < n_sequences:
        if max_attempts is not None and attempts >= max_attempts:
            raise RuntimeError(
                f"found only {len(sequences)} of {n_sequences} valid sequences "
                f"after {max_attempts} shuffles"
            )
        attempts += 1
        shuffler.shuffle(pool)
        sequence = pool.copy()
        # Parity of each trial code; presumably odd/even codes correspond to the
        # two reward outcomes -- confirm against the task design.
        outcome_sequence = [x % 2 for x in sequence]

        if not check_sequence(sequence, n_row):
            continue
        if not check_sequence(outcome_sequence, n_row):
            continue
        if not check_prior(sequence, n_prior):
            continue

        sequences.append(sequence)

    return sequences

In [None]:
def find_iti(iti_lengths, n_row, n_sequences, rng=None, max_attempts=None):
    """Shuffle ``iti_lengths`` repeatedly and collect ``n_sequences`` acceptable orderings.

    An ordering is accepted when no ITI value appears more than ``n_row`` times
    in a row (checked with ``check_sequence``). The caller's ``iti_lengths`` is
    left untouched: shuffling happens on a private copy.

    Parameters
    ----------
    iti_lengths : list of str
        ITI labels to order. They are joined with ``', '.join``, so they must be strings.
    n_row : int
        Maximum run length of identical ITI values.
    n_sequences : int
        Number of accepted orderings to return.
    rng : optional
        Object with a ``shuffle`` method (e.g. ``np.random.default_rng(seed)``) for
        reproducible output. Defaults to the global ``np.random`` state.
    max_attempts : int, optional
        Maximum number of shuffles before giving up. ``None`` (default) retries
        indefinitely, which never terminates if the constraint is unsatisfiable.

    Returns
    -------
    list of str
        Each accepted ordering rendered as a single ``', '``-separated string.

    Raises
    ------
    RuntimeError
        If ``max_attempts`` shuffles are used up before enough orderings are found.
    """
    shuffler = np.random if rng is None else rng
    # Shuffle a private copy in place so the caller's list is never mutated.
    pool = iti_lengths.copy()
    sequences = []
    attempts = 0

    while len(sequences) < n_sequences:
        if max_attempts is not None and attempts >= max_attempts:
            raise RuntimeError(
                f"found only {len(sequences)} of {n_sequences} valid ITI sequences "
                f"after {max_attempts} shuffles"
            )
        attempts += 1
        shuffler.shuffle(pool)

        if not check_sequence(pool, n_row):
            continue
        # join builds a new string, so later shuffles of `pool` cannot alter it.
        sequences.append(', '.join(pool))

    return sequences

In [None]:
# Twenty orderings of the 1/2 block with at most 2 identical trial types or outcomes in a row.
sequences_12 = find_sequence(trial_types=trial_types_12, n_sequences=20, n_row=2, n_prior=4)
print(sequences_12[0])

In [None]:
# Twenty orderings of the 3/4 block under the same constraints as the 1/2 block.
sequences_34 = find_sequence(trial_types=trial_types_34, n_sequences=20, n_row=2, n_prior=4)
print(sequences_34[0])

In [None]:
# Twenty orderings of the combined session (all four trial types).
sequences = find_sequence(trial_types=trial_types, n_sequences=20, n_row=2, n_prior=4)
print(sequences[0])

In [None]:
# Twenty ITI orderings with no ITI length repeated more than twice in a row.
iti_sequences = find_iti(iti_lengths=iti_lengths, n_sequences=20, n_row=2)
print(iti_sequences[0])

In [None]:
# List every combined-session ordering, one per line.
# Note: the loop variable `sequence` remains bound in the notebook namespace afterwards.
for sequence in sequences:
    print(sequence)

In [None]:
# List every 1/2-block ordering, one per line.
for sequence in sequences_12:
    print(sequence)

In [None]:
# List every 3/4-block ordering, one per line.
for sequence in sequences_34:
    print(sequence)

In [None]:
# List every ITI ordering (each already joined into one comma-separated string).
for sequence in iti_sequences:
    print(sequence)

In [None]:
def check_sequences(sequences, n_row, n_prior):
    """Return, in input order, the sequences that pass the same checks used by ``find_sequence``.

    A sequence passes when ``check_sequence`` accepts both the sequence and its
    per-element ``% 2`` outcome codes with limit ``n_row``, and ``check_prior``
    accepts it with limit ``n_prior``. Passing sequences are returned as-is (not copied).
    """
    def _passes(candidate):
        # Outcome codes are computed before any check, mirroring the generation step.
        outcomes = [value % 2 for value in candidate]
        return (
            check_sequence(candidate, n_row)
            and check_sequence(outcomes, n_row)
            and check_prior(candidate, n_prior)
        )

    return [candidate for candidate in sequences if _passes(candidate)]

In [None]:
verified = check_sequences(sequences, n_row=3, n_prior=4)
# `sequences` was generated with the stricter n_row=2 and the same n_prior=4,
# so every sequence should pass here. A shortfall points to stale kernel state,
# e.g. check functions redefined after the sequences were generated.
assert len(verified) == len(sequences), (
    f"{len(sequences) - len(verified)} of {len(sequences)} sequences failed re-verification"
)
len(verified)