In [14]:
import functools
import json
import math
from typing import Tuple, TypeVar
import warnings

import haiku as hk
import jax
import jax.numpy as jnp
import optax
import numpy as np
import pandas as pd
import plotnine as gg
import matplotlib.pyplot as plt

T = TypeVar('T')
Pair = Tuple[T, T]

gg.theme_set(gg.theme_bw())
warnings.filterwarnings('ignore')

In [15]:
# Read from a JSON file
with open('braille_translation.json', 'r') as file:
    data = json.load(file)

bin_to_braille = data['bin_to_braille']
eng_to_bin = data['eng_to_bin']

In [16]:
def encode_character(char: str) -> int:
    """Encode a single character to its ASCII value."""
    return ord(char)

def decode_character(code: int) -> str:
    """Decode an ASCII value back to a character."""
    return chr(code)

def encode_braille_binary(binary_str: str) -> np.ndarray:
    """Encode a Braille binary string to an array of integers."""
    return np.array([int(bit) for bit in binary_str], dtype=np.int32)

# Example decoding function for Braille (you might need to adjust this based on your model's output)
def decode_braille_binary(binary_array: np.ndarray) -> str:
    """Decode an array of integers back to a Braille binary string."""
    return ''.join(map(str, binary_array))


In [17]:
def generate_braille_data(seq_len: int, data_size: int) -> Pair[np.ndarray]:
    english_chars = list(eng_to_bin.keys())
    eng_seqs = np.random.choice(english_chars, (data_size, seq_len))

    # Encode English sequences
    vectorized_encode_char = np.vectorize(encode_character)
    encoded_eng_seqs = vectorized_encode_char(eng_seqs)
    # Assuming encoded_eng_seqs is of shape [data_size, seq_len]
    encoded_eng_seqs = encoded_eng_seqs[:, :, np.newaxis]  # Now shape [data_size, seq_len, 1]


    # Encode Braille sequences
    braille_seqs = np.vectorize(eng_to_bin.get)(eng_seqs)
    vectorized_encode_braille = np.vectorize(encode_braille_binary, signature='()->(n)')
    encoded_braille_seqs = vectorized_encode_braille(braille_seqs)

    return encoded_eng_seqs, encoded_braille_seqs



In [18]:
def split_data(eng_seqs, braille_seqs, train_size, valid_size):
    # Split the data into training and validation sets
    train_x = eng_seqs[:train_size]
    train_y = braille_seqs[:train_size]

    valid_x = eng_seqs[train_size:train_size + valid_size]
    valid_y = braille_seqs[train_size:train_size + valid_size]

    return (train_x, train_y), (valid_x, valid_y)


In [19]:
class Dataset:
    """An iterator over a dataset, revealing batch_size elements at a time."""

    def __init__(self, xy: Pair[np.ndarray], batch_size: int):
        self._x, self._y = xy
        self._batch_size = batch_size
        self._num_batches = self._x.shape[0] // batch_size
        self._idx = 0

    def __next__(self) -> Pair[np.ndarray]:
        if self._idx >= self._num_batches:
            raise StopIteration

        start = self._idx * self._batch_size
        end = start + self._batch_size
        x, y = self._x[start:end], self._y[start:end]
        self._idx += 1
        return x, y

    def __iter__(self):
        return self


In [20]:
repetitions = 500
BATCH_SIZE = 8  # Number of examples per batch
SEQ_LEN = 64  # Length of each sequence
DATA_SIZE = 27 * repetitions  # where repetitions is the number of times you repeat the dataset
TRAIN_SIZE = int(0.75 * DATA_SIZE)  # Example: 75% for training
VALID_SIZE = DATA_SIZE - TRAIN_SIZE  # Remaining for validation
# Generate the data
eng_seqs, braille_seqs = generate_braille_data(SEQ_LEN, DATA_SIZE)

# Split the data
train, valid = split_data(eng_seqs, braille_seqs, TRAIN_SIZE, VALID_SIZE)

# Create dataset objects
train_ds = Dataset(train, BATCH_SIZE)
valid_ds = Dataset(valid, BATCH_SIZE)


In [21]:
def unroll_net(seqs: jnp.ndarray):
    core = hk.LSTM(128)
    batch_size = seqs.shape[1]
    outs, state = hk.dynamic_unroll(core, seqs, core.initial_state(batch_size))
    return hk.BatchApply(hk.Linear(6))(outs), state


model = hk.transform(unroll_net)


def train_model(train_ds: Dataset) -> hk.Params:
    """Initializes and trains a model on train_ds, returning the final params."""
    rng = jax.random.PRNGKey(428)
    opt = optax.adam(1e-3)

    @jax.jit
    def loss(params, x, y):
        pred, _ = model.apply(params, None, x)
        # Use binary cross-entropy loss
        return jnp.mean(optax.sigmoid_binary_cross_entropy(pred, y))

    @jax.jit
    def update(step, params, opt_state, x, y):
        l, grads = jax.value_and_grad(loss)(params, x, y)
        grads, opt_state = opt.update(grads, opt_state)
        params = optax.apply_updates(params, grads)
        return l, params, opt_state

    # Initialize state.
    sample_x, _ = next(train_ds)
    params = model.init(rng, sample_x)
    opt_state = opt.init(params)

    for step in range(1000):
        try:
            x, y = next(train_ds)
        except StopIteration:
            break 
        train_loss, params, opt_state = update(step, params, opt_state, x, y)
        if step % 100 == 0:
            print("Step {}: train loss {}".format(step, train_loss))

    return params


In [22]:
trained_params = train_model(train_ds)

Step 0: train loss 0.6983954310417175
Step 100: train loss 0.576724648475647
Step 200: train loss 0.5514483451843262
Step 300: train loss 0.5117329955101013
Step 400: train loss 0.508243203163147
Step 500: train loss 0.5050299167633057
Step 600: train loss 0.4777008891105652
Step 700: train loss 0.48730605840682983
Step 800: train loss 0.48745638132095337
Step 900: train loss 0.4798251986503601


In [23]:
def predict_braille(model, trained_params, encoded_char):
    # The input encoded_char should already be in shape [1, 1, 1]
    pred, _ = model.apply(trained_params, None, encoded_char)
    # Apply sigmoid and round to get binary values
    pred_binary = jnp.round(jax.nn.sigmoid(pred))
    return pred_binary


In [24]:
input_char = 'c'  # Example character
encoded_input = encode_character(input_char)

# Reshape to [1, 1, 1] for [batch_size, sequence_length, features]
encoded_input = jnp.array([[[encoded_input]]], dtype=jnp.float32)

In [27]:
decode_braille_binary(encoded_input)

'[[99.]]'