In [2]:
import Random: Xoshiro, AbstractRNG

In [11]:
rng = Xoshiro(42)

Xoshiro(0xa379de7eeeb2a4e8, 0x953dccb6b532b3af, 0xf597b8ff8cfd652a, 0xccd7337c571680d1, 0xc90c4a0730db3f7e)

In [14]:
function generate_copy_dataset(rng::AbstractRNG; num_samples::Int=1000, max_length::Int=50, vocab_size::Int=100)
    dataset = []
    for _ in 1:num_samples
        length = rand(rng, 5:max_length)
        sequence = [rand(1:vocab_size) for _ in 1:length]
        push!(dataset, (sequence, sequence))  # Input and target identical
    end
    return dataset
end

generate_copy_dataset (generic function with 1 method)

In [15]:
copy_data = generate_copy_dataset(rng);

In [None]:
function generate_reversal_dataset(rng::AbstractRNG; num_samples=1000, max_length=50, vocab_size=100)
    dataset = []
    for _ in 1:num_samples
        length = rand(rng, 5:max_length)
        sequence = [rand(rng, 1:vocab_size) for _ in 1:length]
        reversed_sequence = reverse(sequence)
        push!(dataset, (sequence, reversed_sequence))
    end
    return dataset
end

generate_reversal_dataset (generic function with 1 method)

In [None]:
reversal_data = generate_reversal_dataset(rng);

In [25]:
function generate_retrieval_dataset(rng::AbstractRNG; num_samples=1000, context_length=100, vocab_size=100, special_token=999)
    dataset = []
    for _ in 1:num_samples
        haystack = [rand(rng, 1:vocab_size) for _ in 1:context_length-1]
        needle_position = rand(rng, 1:context_length-1)
        needle_value = rand(rng, 1:vocab_size)
        insert!(haystack, needle_position, needle_value)
        query = vcat(special_token, needle_position)
        target = [haystack[needle_position]]
        push!(dataset, (vcat(haystack, query), target))
    end
    return dataset
end

generate_retrieval_dataset (generic function with 1 method)

In [26]:
needle_data = generate_retrieval_dataset(rng);

In [28]:
needle_data[1]

([38, 83, 93, 54, 31, 94, 59, 43, 11, 11  …  61, 41, 82, 92, 25, 56, 27, 40, 999, 32], [22])

In [32]:
needle_data[1][1][32]

22

In [36]:
function generate_sorting_dataset(rng::AbstractRNG; num_samples=1000, max_length=20, vocab_size=100)
    dataset = []
    for _ in 1:num_samples
        length = rand(rng, 5:max_length)
        sequence = [rand(rng, 1:vocab_size) for _ in 1:length]
        sorted_sequence = sort(sequence)
        push!(dataset, (sequence, sorted_sequence))
    end
    return dataset
end

generate_sorting_dataset (generic function with 1 method)

In [38]:
sorting_data = generate_sorting_dataset(rng);

In [39]:
function generate_pattern_dataset(rng::AbstractRNG; num_samples::Int=1000, pattern_length::Int=3, vocab_size::Int=100, special_token::Int=999)
    dataset = []
    for _ in 1:num_samples
        context_length = rand(20:50)
        context = [rand(rng, 1:vocab_size) for _ in 1:context_length]
        pattern = [rand(rng, 1:vocab_size) for _ in 1:pattern_length]
        # Insert pattern at random position
        insert_pos = rand(rng, 1:context_length - pattern_length + 1)
        context[insert_pos:insert_pos+pattern_length-1] = pattern
        push!(dataset, (vcat(context, special_token, pattern[1:end-1]), [pattern[end]]))
    end
    return dataset
end

generate_pattern_dataset (generic function with 2 methods)

In [41]:
pattern_data = generate_pattern_dataset(rng);

In [42]:
pattern_data[1]

([77, 15, 31, 90, 19, 19, 97, 1, 13, 66  …  19, 91, 25, 65, 90, 92, 8, 999, 97, 1], [13])

In [9]:
function generate_depth_dataset(rng::AbstractRNG; num_samples::Int=1000, max_length::Int=50, vocab_size::Int=100)
    dataset = []
    for _ in 1:num_samples
        length = rand(5:max_length)
        sequence = [rand(0:vocab_size-1) for _ in 1:length]
        depth_values = [rand() for _ in 1:length]
        # Assign higher depth values to important tokens
        important_indices = shuffle(1:length)[1:div(length,3)]
        for idx in important_indices
            depth_values[idx] += 2.0
        end
        target = [sequence[i] for i in important_indices]
        push!(dataset, ((sequence, depth_values), target))
    end
    return dataset
end

generate_depth_dataset (generic function with 1 method)