In [22]:
from jax import numpy as jnp
import jax
from jax import grad,vmap
from jax import random
import matplotlib.pyplot as plt
from jax.tree_util import register_pytree_node_class
import numpy as np
from jax import lax as jlax
from jax.tree_util import register_pytree_node_class
import json
#import copyself.components
import jaxlib

In [54]:
@register_pytree_node_class
class Parameter:
    """
    A class to represent a parameter with a name and a value, supporting basic arithmetic operations
    and integration with the JAX pytree system.

    Attributes:
    -----------
    name : str
        The name of the parameter.
    value : ndarray
        The value of the parameter, typically a NumPy array or similar.
    shape : tuple
        The shape of the value.
    """
    def __init__(self,name,value):
        self.name = name
        self.value = value
        self.shape = value.shape
    def __sub__(self,param):
        if isinstance(param,Parameter):
            return Parameter(self.name,self.value-param.value)
        raise TypeError(f"unsupported operand type(s) for -: {type(param)} and 'Parameter'")
    def __add__(self,other):
        if isinstance(other,Parameter):
            return Parameter(self.name,self.value+other.value)
        if isinstance(other,float):
            return Parameter(self.name,self.value+other)
        raise TypeError(f"unsupported operand type(s) for +: {type(other)} and 'Parameter'")
    def __mul__(self,factor):
        if isinstance(factor,float):
            return Parameter(self.name,factor*self.value)
        raise TypeError(f'Cannot multiply a Parameter with {type(factor)}')
    def __rmul__(self,factor):
        if isinstance(factor,float):
            return Parameter(self.name,factor*self.value)
        raise TypeError(f'Cannot multiply a Parameter with {type(factor)}')
    def __pow__(self,factor):
        return Parameter(self.name,self.value**factor)
    def __truediv__(self,other):
        if isinstance(other,Parameter):
            return Parameter(self.name,self.value/other.value)
        if isinstance(other,float):
            return Parameter(self.name,self.value/other)
        raise TypeError(f'Cannot divide a Parameter with {type(other)}')
    def tree_flatten(self):
        children = (self.value,)
        aux_data = (self.name,)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)

In [55]:
@register_pytree_node_class
class LinearParams:
    """
    A class to represent the parameters of a linear layer, specifically the weights, which are stored
    as a Parameter object.

    Attributes:
    -----------
    name : str
        The name of the linear layer parameters.
    weights : Parameter
        The weights of the linear layer, stored as a Parameter object.
    """
    def __init__(self,name,weights):
        self.name = name
        if isinstance(weights,Parameter):
            self.weights = weights
        else:
            self.weights = Parameter("W",weights)
    def __sub__(self,other):
        if isinstance(other,LinearParams):
            return LinearParams(self.name,self.weights-other.weights)
        raise TypeError(f"unsupported operand type(s) for -: {type(other)} and 'LinearParams'")
    def __add__(self,other):
        if isinstance(other,LinearParams) :
            return LinearParams(self.name,self.weights+other.weights)
        if isinstance(other,float) :
            return LinearParams(self.name,self.weights+other)
        raise TypeError(f"unsupported operand type(s) for +: {type(other)} and 'LinearParams'")
    def __mul__(self,other):
        if isinstance(other,float):
            return LinearParams(self.name,self.weights*other)
        raise TypeError(f"Cannot multiply a 'LinearParams' with {type(other)}")
    def __rmul__(self,other):
        if isinstance(other,float):
            return LinearParams(self.name,self.weights*other)
        raise TypeError(f"Cannot multiply a 'LinearParams' with {type(other)}")
    def __truediv__(self,other):
        if isinstance(other,LinearParams) :
            return LinearParams(self.name,self.weights/other.weights)
        if isinstance(other,float):
            return LinearParams(self.name,self.weights/other)
        raise TypeError(f"Cannot divide a 'LinearParams' with {type(other)}")
    def __pow__(self,factor):
        return LinearParams(self.name,self.weights**factor)
    def tree_flatten(self):
        children = (self.weights,)
        aux_data = (self.name,)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)

In [25]:
@register_pytree_node_class
class FeedForwardParams:
    """
    A class to represent the parameters of a feedforward neural network layer, specifically the weights and biases,
    which are stored as Parameter objects.

    Attributes:
    -----------
    name : str
        The name of the feedforward layer parameters.
    weights : Parameter
        The weights of the feedforward layer, stored as a Parameter object.
    bias : Parameter
        The bias of the feedforward layer, stored as a Parameter object.
    """
    def __init__(self,name,weights,bias):
        self.name = name
        if isinstance(weights,Parameter):
            self.weights = weights
        else:
            self.weights = Parameter("W",weights)
        if isinstance(bias,Parameter):
            self.bias = bias
        else:
            self.bias = Parameter("bais",bias)
    def __sub__(self,other):
        if isinstance(other,FeedForwardParams) :
            return FeedForwardParams(self.name,self.weights-other.weights,self.bias-other.bias)
        raise TypeError(f"unsupported operand type(s) for -: {type(other)} and 'FeedForwardParams'")
    def __add__(self,other):
        if isinstance(other,FeedForwardParams):
            return FeedForwardParams(self.name,self.weights+other.weights,self.bias+other.bias)
        if isinstance(other,float):
            return FeedForwardParams(self.name,self.weights+other,self.bias+other)
        raise TypeError(f"unsupported operand type(s) for +: {type(other)} and 'FeedForwardParams'")
    def __mul__(self,other):
        if isinstance(other,float):
            return FeedForwardParams(self.name,self.weights*other,self.bias*other)
        raise TypeError(f"Cannot multiply a 'FeedForwardParams' with {type(other)}")
    def __rmul__(self,other):
        if isinstance(other,float):
            return FeedForwardParams(self.name,self.weights*other,self.bias*other)
        raise TypeError(f"Cannot multiply a 'FeedForwardParams' with {type(other)}")
    def __truediv__(self,other):
        if isinstance(other,FeedForwardParams):
            return FeedForwardParams(self.name,self.weights/other.weights,self.bias/other.bias)
        if isinstance(other,float):
            return FeedForwardParams(self.name,self.weights/other,self.bias/other)
        raise TypeError(f"Cannot divide a 'FeedForwardParams' with {type(other)}")
    def __pow__(self,factor):
        return FeedForwardParams(self.name,self.weights**factor,self.bias**factor)
    def tree_flatten(self):
        children = (self.weights,self.bias,)
        aux_data = (self.name,)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)
    

In [26]:
@register_pytree_node_class
class AttentionParams:
    """
    A class to represent the parameters of an attention mechanism, specifically the weights for key, query, and value,
    which are stored as Parameter objects.

    Attributes:
    -----------
    name : str
        The name of the attention parameters.
    w_k : Parameter
        The weights for the key, stored as a Parameter object.
    w_q : Parameter
        The weights for the query, stored as a Parameter object.
    w_v : Parameter
        The weights for the value, stored as a Parameter object.
    """
    def __init__(self,name,w_k,w_q,w_v):
        self.name = name
        if isinstance(w_k,Parameter):
            self.w_k = w_k
        else:
            self.w_k = Parameter("w_k",w_k)
        if isinstance(w_q,Parameter):
            self.w_q = w_q
        else:
            self.w_q = Parameter("w_q",w_q)
        if isinstance(w_v,Parameter):
            self.w_v = w_v
        else:
            self.w_v = Parameter("w_v",w_v)
        
    def __sub__(self,other):
        if isinstance(other,AttentionParams):
            return AttentionParams(self.name,self.w_k-other.w_k,self.w_q-other.w_q,self.w_v-other.w_v)
        raise TypeError(f"unsupported operand type(s) for -: {type(other)} and 'AttentionParams'")
    def __add__(self,other):
        if isinstance(other,AttentionParams):
            return AttentionParams(self.name,self.w_k+other.w_k,self.w_q+other.w_q,self.w_v+other.w_v)
        if isinstance(other,float):
            return AttentionParams(self.name,self.w_k+other,self.w_q+other,self.w_v+other)
        raise TypeError(f"unsupported operand type(s) for +: {type(other)} and 'AttentionParams'")
    def __mul__(self,other):
        if isinstance(other,float):
            return AttentionParams(self.name,self.w_k*other,self.w_q*other,self.w_v*other)
        raise TypeError(f"Cannot multiply a 'AttentionParams' with {type(other)}")
    def __truediv__(self,other):
        if isinstance(other,AttentionParams):
            return AttentionParams(self.name,self.w_k/other.w_k,self.w_q/other.w_q,self.w_v/other.w_v)
        if isinstance(other,float):
            return AttentionParams(self.name,self.w_k/other,self.w_q/other,self.w_v/other)
        raise TypeError(f"Cannot divide a 'AttentionParams' with {type(other)}")
    def __rmul__(self,other):
        if isinstance(other,float):
            return AttentionParams(self.name,self.w_k*other,self.w_q*other,self.w_v*other)
        raise TypeError(f"Cannot multiply a 'AttentionParams' with {type(other)}")
    def __pow__(self,factor):
        return AttentionParams(self.name,self.w_k**factor,self.w_q**factor,self.w_v**factor)
    def tree_flatten(self):
        children = (self.w_k,self.w_q,self.w_v,)
        aux_data = (self.name,)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)

In [27]:
@register_pytree_node_class
class MultiHeadAttentionParams:
    """
    A class to represent the parameters of a multi-head attention mechanism, including the output weights
    and a list of individual attention heads.

    Attributes:
    -----------
    name : str
        The name of the multi-head attention parameters.
    weights : Parameter
        The output weights of the multi-head attention, stored as a Parameter object.
    heads : list[AttentionParams]
        A list of AttentionParams objects representing the individual attention heads.
    num_heads : int
        The number of attention heads.
    """
    def __init__(self,name,weights,heads:list[AttentionParams]):
        self.name = name
        if isinstance(weights,Parameter):
            self.weights = weights
        else:
            self.weights = Parameter("Wo",weights)
        self.heads = heads
        self.num_heads = len(heads)
    def add(self,head1,head2):
        return head1+head2
    def subtract(self,head1,head2):
        return head1-head2
    def multiply(self,val,head):
        return val*head
    def divide(self,head,val):
        return head/val
    def pow(self,val,head):
        return head**val
    def __sub__(self,other):
        if isinstance(other,MultiHeadAttentionParams):
            heads = list(map(self.subtract,self.heads,other.heads))
            return MultiHeadAttentionParams(self.name,self.weights-other.weights,heads)
        raise TypeError(f"unsupported operand type(s) for -: {type(other)} and 'MultiHeadAttentionParams'")
    def __add__(self,other):
        if isinstance(other,MultiHeadAttentionParams) :
            heads = list(map(self.add,self.heads,other.heads))
            return MultiHeadAttentionParams(self.name,self.weights+other.weights,heads)
        if isinstance(other,float):
            heads = list(map(self.add,self.heads,[other]*self.num_heads))
            return MultiHeadAttentionParams(self.name,self.weights+other,heads)
        raise TypeError(f"unsupported operand type(s) for +: {type(other)} and 'MultiHeadAttentionParams'")
    def __mul__(self,other):
        if isinstance(other,float):
            heads = list(map(self.multiply,[other]*self.num_heads,self.heads))
            return MultiHeadAttentionParams(self.name,self.weights*other,heads)
        raise TypeError(f"Cannot multiply a 'MultiHeadAttentionParams' with {type(other)}")
    def __truediv__(self,other):
        if isinstance(other,MultiHeadAttentionParams) :
            heads = list(map(self.divide,self.heads,other.heads))
            return MultiHeadAttentionParams(self.name,self.weights/other.weights,heads)
        if isinstance(other,float):
            heads = list(map(self.divide,self.heads,[other]*self.num_heads))
            return MultiHeadAttentionParams(self.name,self.weights/other,heads)
        raise TypeError(f"Cannot multiply a 'MultiHeadAttentionParams' with {type(other)}")
    def __rmul__(self,other):
        if isinstance(other,float):
            heads = list(map(self.multiply,[other]*self.num_heads,self.heads))
            return MultiHeadAttentionParams(self.name,self.weights*other,heads)
        raise TypeError(f"Cannot multiply a 'MultiHeadAttentionParams' with {type(other)}")
    def __pow__(self,factor):
        heads = list(map(self.pow,[factor]*self.num_heads,self.heads))
        return MultiHeadAttentionParams(self.name,self.weights**factor,heads)
    def tree_flatten(self):
        children = (self.weights,self.heads,)
        aux_data = (self.name,)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)

In [28]:
@register_pytree_node_class
class ModuleParams:
    """
    A class to represent the parameters of a module, including a list of its components.

    Attributes:
    -----------
    name : str
        The name of the module parameters.
    components : list
        A list of components that make up the module.
    num_comps : int
        The number of components in the module.
    """
    def __init__(self,name,components):
        self.name = name
        self.components = components
        self.num_comps = len(components)
    def multiply(self,val,comp):
        return val*comp
    def subtract(self,comp1,comp2):
        return comp1-comp2
    def add(self,comp1,comp2):
        return comp1+comp2
    def pow(self,val,comp):
        return comp**val
    def divide(self,comp,val):
        return comp/val
    def __sub__(self,other):
        if isinstance(other,ModuleParams):
            comps = list(map(self.subtract,self.components,other.components))
            return ModuleParams(self.name,comps)
        raise TypeError(f"unsupported operand type(s) for -: {type(other)} and 'ModuleParams'")
    def __add__(self,other):
        if isinstance(other,ModuleParams) :
            comps = list(map(self.add,self.components,other.components))
            return ModuleParams(self.name,comps)
        if isinstance(other,float):
            comps = list(map(self.add,self.components,[other]*self.num_comps))
            return ModuleParams(self.name,comps)
        raise TypeError(f"unsupported operand type(s) for +: {type(other)} and 'ModuleParams'")
    def __mul__(self,other):
        if isinstance(other,float):
            comps = list(map(self.multiply,[other]*self.num_comps,self.components))
            return ModuleParams(self.name,comps)
        raise TypeError(f"Cannot multiply a 'ModuleParams' with {type(other)}")
    def __truediv__(self,other):
        if isinstance(other,ModuleParams) :
            comps = list(map(self.divide,self.components,other.components))
            return ModuleParams(self.name,comps)
        if isinstance(other,float):
            comps = list(map(self.divide,self.components,[other]*self.num_comps))
            return ModuleParams(self.name,comps)
        raise TypeError(f"Cannot divide a 'ModuleParams' with {type(other)}")
    def __rmul__(self,other):
        if isinstance(other,float):
            comps = list(map(self.multiply,[other]*self.num_comps,self.components))
            return ModuleParams(self.name,comps)
        raise TypeError(f"Cannot multiply a 'ModuleParams' with {type(other)}")
    def __pow__(self,factor):
        comps = list(map(self.pow,[factor]*self.num_comps,self.components))
        return ModuleParams(self.name,comps)
    def tree_flatten(self):
        children = (self.components,)
        aux_data = (self.name,)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)

In [29]:
class Dropout:
    """
    A class to represent dropout regularization in a neural network.

    Attributes:
    -----------
    dropout_p : float
        The probability of dropping out a unit, defaulting to 0.2.
    seed : int
        The seed for random number generation, defaulting to 0.
    """
    def __init__(self,dropout_p,seed=0):
        self.dropout_p = 0.2
        self.seed = seed
    def predict(self,x):
        import random as rnd
        _,key_ = random.split(random.key(rnd.randint(0,1000)))
        mask_ = random.bernoulli(key_,1-self.dropout_p,shape=x.shape)
        dropout_out = mask_*x
        scale = 1/(1-self.dropout_p)
        return dropout_out*scale
    def batched_predict(self,x):
        predictor = vmap(self.predict,in_axes=(0))
        return predictor(x)
    def __call__(self,x):
        if len(x.shape)>1:
            return self.batched_predict(x)
        return self.predict(x)
        


In [30]:
@register_pytree_node_class
class LinearLayer:
    """
    A class to represent a linear layer in a neural network.

    Attributes:
    -----------
    in_units : int
        The number of input units for the linear layer.
    out_units : int
        The number of output units for the linear layer.
    params : LinearParams
        The parameters of the linear layer.
    key : jax.random.PRNGKey
        The random key for parameter initialization.

    Class Methods:
    --------------
    initiate_params(cls, name, in_units, out_units, key, scale=1e-2):
        Initializes the parameters for the linear layer.
    """
    @classmethod
    def initiate_params(cls,name,in_units,out_units,key,scale=1e-2):
        w_key,_= random.split(key,2)
        initializer = jax.nn.initializers.he_normal()
        params = {}
        #params["W"] = random.normal(w_key,shape = (n_vocab,embedding_dims),dtype=jnp.float32)*scale
        initializer = jax.nn.initializers.he_normal()
        #params["W"] = initializer(w_key,shape = (n_vocab,embedding_dims),dtype=jnp.float32)*scale
        params = LinearParams(name,initializer(w_key,shape = (in_units,out_units),dtype=jnp.float32)*scale)
        return params
    def __init__(self,name,in_units,out_units,params=None):
        self.in_units = in_units
        self.out_units = out_units
        self.params = params
        self.key = random.key(210)
        if params==None:
            self.params = LinearLayer.initiate_params(name,self.n_vocab,self.embedding_dims,self.key)
    def predict(self,x):
        x = jnp.matmul(x,self.params.weights.value)
        return x
    def batched_predict(self,x):
        predictor = vmap(self.predict,in_axes=[0])
        return predictor(x)
    def __call__(self,x):
        if len(x.shape)>2:
            return self.batched_predict(x)
        return self.predict(x)
        #print(x)
        
    def tree_flatten(self):
        children = (self.params,)
        aux_data = (self.in_units,self.out_units)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)

In [31]:
@register_pytree_node_class
class EmbeddingLayer:
    """
    A class to represent an embedding layer in a neural network, including parameter initialization 
    and positional encoding.

    Attributes:
    -----------
    n_vocab : int
        The size of the vocabulary for the embedding layer.
    embedding_dims : int
        The dimensionality of the embeddings.
    params : LinearParams
        The parameters of the embedding layer.
    key : jax.random.PRNGKey
        The random key for parameter initialization.

    Class Methods:
    --------------
    initiate_params(cls, name, n_vocab, embedding_dims, key, scale=1e-1):
        Initializes the parameters for the embedding layer.
    positional_enc(cls, emb_dims, seq_len):
        Computes positional encodings for a given sequence length and embedding dimensions.

    Methods:
    --------
    one_hot(self, x, max):
        Converts input indices into one-hot encoded vectors.
    """
    @classmethod
    def initiate_params(cls,name,n_vocab,embedding_dims,key,scale=1e-1):
        w_key,_= random.split(key,2)
        initializer = jax.nn.initializers.he_normal()
        params = {}
        #params["W"] = random.normal(w_key,shape = (n_vocab,embedding_dims),dtype=jnp.float32)*scale
        initializer = jax.nn.initializers.he_normal()
        #params["W"] = initializer(w_key,shape = (n_vocab,embedding_dims),dtype=jnp.float32)*scale
        params = LinearParams(name,initializer(w_key,shape = (n_vocab,embedding_dims),dtype=jnp.float32)*scale)
        return params   
    @classmethod
    def positional_enc(cls,emb_dims,seq_len):
        pos = jnp.arange(seq_len)[:, jnp.newaxis]
        pe = jnp.zeros((seq_len,emb_dims))
        div_terms = jnp.exp(jnp.arange(0, emb_dims, 2) * -(jnp.log(10000.0) / emb_dims))
        pe = pe.at[:, 0::2].set(jnp.sin(pos*div_terms))
        pe = pe.at[:, 1::2].set(jnp.cos(pos*div_terms))
        return pe
    def one_hot(self,x,max):
        return jnp.array(x[:,:,None]==jnp.arange(max),dtype=jnp.float32)
    def __init__(self,name,n_vocab,embedding_dims,params=None):
        self.n_vocab = n_vocab
        self.embedding_dims = embedding_dims
        self.params = params
        self.key = random.key(210)
        if params==None:
            self.params = EmbeddingLayer.initiate_params(name,self.n_vocab,self.embedding_dims,self.key)
    def predict(self,x,mask):
        seq_len = x.shape[-1]
        x = self.one_hot(x,self.n_vocab)
        x = jnp.matmul(x,self.params.weights.value)+EmbeddingLayer.positional_enc(self.embedding_dims,seq_len)
        mask = jnp.expand_dims(mask,axis=-1)
        x=x*mask+jnp.ones(shape=mask.shape)*1e-12
        return x
    def batched_predict(self,x,mask):
        predictor = vmap(self.predict,in_axes=[0,0])
        return predictor(x)
    def __call__(self,x,mask):
        if len(x.shape)>2:
            return self.batched_predict(x,mask)
        return self.predict(x,mask)
        #print(x)
        
    def tree_flatten(self):
        children = (self.params,)
        aux_data = (self.n_vocab,self.embedding_dims)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)

In [32]:
def relu(x):
    """
    Applies the Rectified Linear Unit (ReLU) activation function element-wise.

    The ReLU function is defined as:
        ReLU(x) = max(0, x)

    Parameters:
    -----------
    x : array-like
        The input array or tensor to which the ReLU function will be applied.

    Returns:
    --------
    jnp.ndarray
        An array or tensor of the same shape as `x`, with the ReLU function applied element-wise.
    """
    return jnp.maximum(0, x)

In [33]:
@register_pytree_node_class
class FeedForward:
    """
    A class to represent a feedforward neural network layer with configurable activation function, weights, and biases.

    Attributes:
    -----------
    activation : callable
        The activation function to apply after the linear transformation. Defaults to the identity function.
    units : int
        The number of units (or neurons) in the feedforward layer.
    key : jax.random.PRNGKey
        The random key for parameter initialization.
    d_model : int
        The dimensionality of the input to the feedforward layer.
    params : FeedForwardParams
        The parameters of the feedforward layer, including weights and biases.

    Class Methods:
    --------------
    initiate_params(name, input_shape, units, key, scale=1e-4):
        Initializes the parameters for the feedforward layer.
    """
    def initiate_params(name,input_shape,units,key,scale=1e-4):
        w_key,b_key = random.split(key,2)
        params = {}
        #params["W"] = random.normal(w_key,shape = (input_shape,units),dtype=jnp.float32)*scale
        #params["b"] = random.normal(b_key,shape = (units,))*scale
        
        initializer = jax.nn.initializers.he_normal()
        #params["W"] = initializer(w_key,shape = (input_shape,units),dtype=jnp.float32)*scale
        params = FeedForwardParams(name,
                                   weights = initializer(w_key,shape = (input_shape,units),dtype=jnp.float32)*scale,
                                   bias = random.normal(b_key,shape = (units,))*scale)
        #params["b"] = initializer(b_key,shape = (units,))*scale
        return params
    def __init__(self,name,d_model,units,activation=lambda x:x,params=None):
        self.activation = activation
        self.units = units
        self.key = random.key(210)
        self.d_model = d_model
        if params == None:
            self.params = FeedForward.initiate_params(name,d_model,self.units,self.key)
        else:
            self.params = params
    def predict(self,input):
        return self.activation(jnp.matmul(input,self.params.weights.value)+self.params.bias.value)
    def batched_predict(self,inputs):
        predictor = vmap(self.predict,in_axes = (0))
        return predictor(inputs)
    def __call__(self,input):
        if len(input.shape)>1:
            return self.batched_predict(input)
        return self.predict(input)
    def tree_flatten(self):
        children = (self.params,)
        aux_data = (self.d_model,self.units,self.activation)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)

In [34]:
@register_pytree_node_class
class AttentionHead:
    """
    A class to represent an individual attention head in a multi-head attention mechanism.

    Attributes:
    -----------
    d : int
        The dimensionality of the query, key, and value vectors for the attention head.
    d_model : int
        The dimensionality of the input to the attention head.
    key : jax.random.PRNGKey
        The random key for parameter initialization.
    params : AttentionParams
        The parameters of the attention head, including weights for query, key, and value.

    Class Methods:
    --------------
    initiate_params(name, input_shape, units, key, scale=1e-2):
        Initializes the parameters for the attention head.
    """
    def initiate_params(name,input_shape,units,key,scale=1e-2):
        q_key,k_key,v_key = random.split(key,3)
        
        params = {}
        # params["Wq"] = random.normal(q_key,shape = (input_shape,units),dtype=jnp.float32)*scale
        # params["Wk"] = random.normal(k_key,shape = (input_shape,units),dtype=jnp.float32)*scale
        # params['Wv'] = random.normal(v_key,shape = (input_shape,units),dtype=jnp.float32)*scale
        initializer = jax.nn.initializers.he_normal()
        # params["Wq"] = initializer(q_key,shape = (input_shape,units),dtype=jnp.float32)*scale
        # params["Wk"] = initializer(k_key,shape = (input_shape,units),dtype=jnp.float32)*scale
        # params['Wv'] = initializer(v_key,shape = (input_shape,units),dtype=jnp.float32)*scale
        params = AttentionParams(name=name,
                                 w_q = initializer(q_key,shape = (input_shape,units),dtype=jnp.float32)*scale,
                                 w_k = initializer(k_key,shape = (input_shape,units),dtype=jnp.float32)*scale,
                                 w_v = initializer(v_key,shape = (input_shape,units),dtype=jnp.float32)*scale
                                )
                                 
        return params
    def __init__(self,name,d,d_model,params=None):
        self.d = d
        self.d_model = d_model 
        self.key = random.key(210)
        self.params = params
        if params ==None:
            self.params = AttentionHead.initiate_params(name,self.d_model,self.d,self.key)
    def predict(self,x_q,x_k,x_v,mask,decoder=False):
        query = jnp.matmul(x_q,self.params.w_q.value)
        key = jnp.matmul(x_k,self.params.w_k.value)
        value = jnp.matmul(x_v,self.params.w_v.value)
        #print("Attenion Shapes:",query.shape,key.shape,value.shape)
        attn_scores = jnp.matmul(query,key.T)/jnp.sqrt(self.d)
        if mask != None:
            mask = jnp.expand_dims(mask,axis=0)
            #print(mask*mask.T)
            attn_scores = attn_scores*(mask*mask.T) +(mask*mask.T!=1)*(-1e-20)
        #print(attn_scores)
        #print(attn_scores.shape)
        softmaxed_attn = jax.nn.softmax(attn_scores)
        if mask != None:
            softmaxed_attn = softmaxed_attn*(mask*mask.T) +(mask*mask.T!=1)*(1e-32)
        #softmaxed_attn = jnp.nan_to_num(softmaxed_attn)
        #print(softmaxed_attn)
        if decoder:
            softmaxed_attn = softmaxed_attn*jnp.triu(jnp.ones(attn_scores.shape))
        #print(softmaxed_attn)
        #print("Value Matrix:",value.shape)
        return jnp.matmul(softmaxed_attn,value)
    def batched_predict(self,x_q,x_k,x_v,mask,decoder=False):
        predictor = vmap(self.predict,in_axes = (0,0,0,0,None))
        return predictor(x_q,x_k,x_v,mask,decoder)
    def __call__(self,x_q,x_k,x_v,mask,decoder=False):
        if len(x_q.shape)>1:
            #print(x_q.shape)
            return self.batched_predict(x_q,x_k,x_v,mask,decoder)
        return self.predict(x_q,x_k,x_v,mask,decoder)
    def tree_flatten(self):
        children = (self.params,)
        aux_data = (self.d,self.d_model)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)

In [35]:
@register_pytree_node_class
class MultiHeadAttention:
    """
    A class to represent a multi-head attention mechanism in a neural network.

    Attributes:
    -----------
    h : int
        The number of attention heads.
    d_model : int
        The dimensionality of the input and output for the multi-head attention.
    key : jax.random.PRNGKey
        The random key for parameter initialization.
    d : int
        The dimensionality of each attention head.
    params : MultiHeadAttentionParams
        The parameters of the multi-head attention, including output weights and individual attention heads.
    attentionHeads : list[AttentionHead]
        A list of AttentionHead instances representing each attention head.

    Class Methods:
    --------------
    initiate_params(name, num_heads, input_shape, units, key, scale=1e-3):
        Initializes the parameters for the multi-head attention mechanism.
    """
    def initiate_params(name,num_heads,input_shape,units,key,scale=1e-3):
        o_key,*h_key = random.split(key,num_heads+1)
        #print(o_key,h_key)
        params = {}
        #params['Wo'] = random.normal(o_key,shape = (input_shape,units),dtype=jnp.float32)*scale
        initializer = jax.nn.initializers.he_normal()
        #params['Wo'] = initializer(o_key,shape = (input_shape,units),dtype=jnp.float32)*scale
        params = MultiHeadAttentionParams(name,
                                          weights = initializer(o_key,shape = (input_shape,units),dtype=jnp.float32)*scale,
                                          heads = [AttentionHead.initiate_params(f"H{i}",
                                                                                 input_shape,
                                                                                 input_shape//num_heads,
                                                                                 h_key[i]) 
                                                   for i in range(num_heads)
                                                  ]
                                         )
        return params
    def __init__(self,name,h,d_model,params=None):
        self.h = h
        self.d_model = d_model
        self.key = random.key(210)
        self.d = d_model//h
        self.params = params
        if params ==None:
            self.params = MultiHeadAttention.initiate_params(name,self.h,self.d_model,self.d_model,self.key)
        
        if self.d_model%self.h!=0:
            raise "D_model not divisible by number of heads"
        self.attentionHeads = [AttentionHead(f"H{i}",
                                             self.d,
                                             self.d_model,
                                             self.params.heads[i]) 
                               for i in range(self.h)]
        
    #def predict(self,x_q,x_k,x_v,mask=None,decoder=False):
        #return jnp.matmul(self.params['Wo'],jnp.concat([head.predict(x_q_i,x_k_i,x_v_i,mask,decoder) for head,x_q_i,x_k_i,x_v_i in zip(self.attentionHeads,[x_q]*8,[x_k]*8,[x_v]*8)]))
    def calc_attentions(self,x_q,x_k,x_v,mask=None,decoder=False):
        concat_attn = jnp.concat([head.batched_predict(x_q_i,x_k_i,x_v_i,mask,decoder) 
                                  for head,x_q_i,x_k_i,x_v_i in 
                                  zip(self.attentionHeads,[x_q]*8,[x_k]*8,[x_v]*8)],
                                 axis=-1)
        return concat_attn
    def predict(self,attns):
        return jnp.matmul(attns,self.params.weights.value)
    def batched_predict(self,attns):
        predictor = vmap(self.predict,in_axes = (0))
        return predictor(attns)
    def __call__(self,x_q,x_k,x_v,mask=None,decoder=False):
        attns = self.calc_attentions(x_q,x_k,x_v,mask,decoder)
        if len(x_q.shape)>1: 
            return self.batched_predict(attns)
        return self.predict(attns)
    def tree_flatten(self):
        children = (self.params,)
        aux_data = (self.h,self.d_model)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)

In [36]:
@register_pytree_node_class
class EncoderLayer:
    """
    A class to represent an encoder layer in a transformer model, including multi-head attention, 
    feedforward layers, and dropout.

    Attributes:
    -----------
    d_model : int
        The dimensionality of the input and output for the encoder layer.
    d_ff : int
        The dimensionality of the hidden layer in the feedforward network.
    num_heads : int
        The number of attention heads in the multi-head attention mechanism.
    params : ModuleParams
        The parameters of the encoder layer, including multi-head attention and feedforward layers.
    key : jax.random.PRNGKey
        The random key for parameter initialization.
    ff1 : FeedForward
        The first feedforward network in the encoder layer.
    ff2 : FeedForward
        The second feedforward network in the encoder layer.
    mha : MultiHeadAttention
        The multi-head attention mechanism in the encoder layer.
    dropout : Dropout
        The dropout regularization applied to the encoder layer.

    Class Methods:
    --------------
    layer_normalization(output, epsilon=1e-9):
        Applies layer normalization to the input tensor.
    """
    def layer_normalization(output, epsilon=1e-9):
        """
        Applies layer normalization to the input tensor.

        Parameters:
        -----------
        output : jnp.ndarray
            The input tensor to be normalized.
        epsilon : float, optional
            A small constant added to the variance to avoid division by zero. Defaults to 1e-9.

        Returns:
        --------
        jnp.ndarray
            The normalized tensor.
        """
        H = output.shape[-1]
        mean = jnp.expand_dims(output.mean(axis=-1), axis=-1)
        std = jnp.expand_dims(output.std(axis=-1), axis=-1)
        output = (output - mean) / (std + epsilon)
        return output

    def __init__(self,name,d_model,d_ff,num_heads,params=None):
        self.d_model = d_model
        self.d_ff = d_ff
        self.num_heads = num_heads
        self.params = params
        self.key = random.key(210)
        if params==None:
            ff1_key,ff2_key,mha_key = random.split(self.key,3)
            self.params = ModuleParams(name,
                                       [MultiHeadAttention.initiate_params('mha',num_heads,d_model,d_model,mha_key),
                                        FeedForward.initiate_params('ff1',d_model,d_ff,ff1_key),
                                        FeedForward.initiate_params('ff2',d_ff,d_model,ff2_key)])
        
        self.ff1 = FeedForward("ff1",d_model,d_ff,params=self.params.components[1])
        self.ff2 = FeedForward("ff2",d_ff,d_model,params=self.params.components[2])
        #self.__name__ = f"EncoderLayer{num}"
        self.mha = MultiHeadAttention("mha",num_heads,d_model,self.params.components[0])
        self.dropout = Dropout(0.2)
       
        #print(self.params)
    def predict(self,input,mask):
        attentions = self.mha(input,input,input,mask)
        #print(mask)
        attentions = self.dropout(attentions)
        #print("Attentions")
        #print(attentions)
        x = EncoderLayer.layer_normalization(input+attentions)
        #print("x+attentions")
       # print(x)
        ff_ = self.ff2(self.ff1(x))
        ff_ = self.dropout(ff_)
        #print("x+ff_")
        #print(x+ff_)
        x = EncoderLayer.layer_normalization(x+ff_)
        #print(x)
        return x
    # def batched_predict(self,inputs,mask):
    #     predictor = vmap(self.predict,in_axes=(0,0))
    #     return predictor(inputs,mask)
    def __call__(self,inputs,mask):
        # if len(inputs.shape)>1:
        #     return self.batched_predict(inputs,mask)
        return self.predict(inputs,mask)
    def tree_flatten(self):
        children = (self.params,)
        aux_data = (self.d_model,self.d_ff,self.num_heads)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)
        

In [37]:
#EncoderLayer.layer_normalization(random.normal(random.key(210),shape=(64,128,512)))

In [38]:
class EncoderLayerParams:
    """
    A class to encapsulate the parameters for an encoder layer in a transformer model.

    Attributes:
    -----------
    params : ModuleParams
        The parameters of the encoder layer, including multi-head attention and feedforward layers.

    Parameters:
    -----------
    name : str
        The name associated with the encoder layer parameters.
    d_model : int
        The dimensionality of the input and output for the encoder layer.
    d_ff : int
        The dimensionality of the hidden layer in the feedforward network.
    num_heads : int
        The number of attention heads in the multi-head attention mechanism.
    key : jax.random.PRNGKey
        The random key for parameter initialization.
    """
    def __init__(self, name, d_model, d_ff, num_heads, key):
        ff1_key, ff2_key, mha_key = random.split(key, 3)
        self.params = ModuleParams(
            name,
            [
                MultiHeadAttention.initiate_params('mha', num_heads, d_model, d_model, mha_key),
                FeedForward.initiate_params('ff1', d_model, d_ff, ff1_key),
                FeedForward.initiate_params('ff2', d_ff, d_model, ff2_key)
            ]
        )
        

class EncoderParams:
    """
    A class to encapsulate the parameters for an encoder in a transformer model, consisting of multiple encoder layers.

    Attributes:
    -----------
    params : ModuleParams
        The parameters of the encoder, including multiple encoder layers.

    Parameters:
    -----------
    name : str
        The name associated with the encoder parameters.
    d_model : int
        The dimensionality of the input and output for each encoder layer.
    d_ff : int
        The dimensionality of the hidden layer in the feedforward network within each encoder layer.
    num_heads : int
        The number of attention heads in the multi-head attention mechanism of each encoder layer.
    num_layers : int
        The number of encoder layers in the encoder.
    key : jax.random.PRNGKey
        The random key for parameter initialization.
    """
    def __init__(self, name, d_model, d_ff, num_heads, num_layers, key):
        keys = random.split(key, num_layers)
        self.params = ModuleParams(
            name,
            [
                EncoderLayerParams(f"L{i}", d_model, d_ff, num_heads, key_).params
                for i, key_ in enumerate(keys)
            ]
        )
        
@register_pytree_node_class
class Encoder:
    """
    A class to represent an encoder in a transformer model, consisting of multiple encoder layers.

    Attributes:
    -----------
    d_model : int
        The dimensionality of the input and output for each encoder layer.
    d_ff : int
        The dimensionality of the hidden layer in the feedforward network within each encoder layer.
    num_heads : int
        The number of attention heads in the multi-head attention mechanism of each encoder layer.
    num_layers : int
        The number of encoder layers in the encoder.
    key : jax.random.PRNGKey
        The random key for parameter initialization.
    params : ModuleParams
        The parameters of the encoder, including multiple encoder layers.
    layers : list[EncoderLayer]
        A list of `EncoderLayer` instances representing each layer in the encoder.

    Parameters:
    -----------
    name : str
        The name associated with the encoder.
    d_model : int
        The dimensionality of the input and output for each encoder layer.
    d_ff : int
        The dimensionality of the hidden layer in the feedforward network within each encoder layer.
    num_heads : int
        The number of attention heads in the multi-head attention mechanism of each encoder layer.
    num_layers : int
        The number of encoder layers in the encoder.
    params : ModuleParams, optional
        The parameters of the encoder. If not provided, they will be initialized.
    """
    def __init__(self,name,d_model,d_ff,num_heads,num_layers,params=None):
        #self.num= generate_number(num_layers)
        self.d_model=d_model
        self.d_ff=d_ff
        self.num_heads=num_heads
        self.num_layers=num_layers
        self.key = random.key(210)
        self.params = params
        if params==None:
            self.params = ModuleParams(name,
                                   [EncoderLayerParams(f"L{i}",d_model,d_ff,num_heads,key_).params 
                                    for i,key_ in enumerate(keys)])
        self.layers = [EncoderLayer(f"L{i}",d_model,d_ff,num_heads,self.params.components[i]) for i in range(num_layers)]
    def __call__(self,input,mask):
        x = input
        #print("Encoder Input Shape:",x.shape)
        for layer in self.layers:
            x = layer(x,mask)
        return x
    def tree_flatten(self):
        children = (self.params,)
        aux_data = (self.d_model,self.d_ff,self.num_heads,self.num_layers)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)

In [56]:
class DecoderLayerParams:
    """
    A class to encapsulate the parameters for a decoder layer in a transformer model.

    Attributes:
    -----------
    params : ModuleParams
        The parameters of the decoder layer, including multi-head attention mechanisms and feedforward layers.

    Parameters:
    -----------
    name : str
        The name associated with the decoder layer parameters.
    d_model : int
        The dimensionality of the input and output for the decoder layer.
    d_ff : int
        The dimensionality of the hidden layer in the feedforward network within the decoder layer.
    num_heads : int
        The number of attention heads in the multi-head attention mechanisms of the decoder layer.
    key : jax.random.PRNGKey
        The random key for parameter initialization.
    """
    def __init__(self,name,d_model,d_ff,num_heads,key):
        self.params = {}
        ff1_key,ff2_key,e_mha_key,d_mha_key = random.split(key,4)
        self.params = ModuleParams(name,
                                [MultiHeadAttention.initiate_params('d_mha',num_heads,d_model,d_model,d_mha_key),
                                 MultiHeadAttention.initiate_params('e_mha',num_heads,d_model,d_model,e_mha_key),
                                 FeedForward.initiate_params('ff1',d_model,d_ff,ff1_key),
                                 FeedForward.initiate_params('ff2',d_ff,d_model,ff2_key)])

@register_pytree_node_class        
class DecoderLayer:
    """
    A class to represent a decoder layer in a transformer model, including multi-head attention mechanisms,
    feedforward layers, and dropout.

    Attributes:
    -----------
    d_model : int
        The dimensionality of the input and output for the decoder layer.
    num_heads : int
        The number of attention heads in the multi-head attention mechanisms of the decoder layer.
    key : jax.random.PRNGKey
        The random key for parameter initialization.
    params : ModuleParams, optional
        The parameters of the decoder layer. If not provided, they will be initialized.
    ff1 : FeedForward
        The first feedforward network in the decoder layer.
    ff2 : FeedForward
        The second feedforward network in the decoder layer.
    d_mha : MultiHeadAttention
        The multi-head attention mechanism for decoder inputs in the decoder layer.
    e_mha : MultiHeadAttention
        The multi-head attention mechanism for encoder outputs in the decoder layer.
    dropout : Dropout
        The dropout regularization applied to the decoder layer.

    Class Methods:
    --------------
    layer_normalization(output, epsilon=1e-9):
        Applies layer normalization to the input tensor.
    """
    def layer_normalization(output, epsilon=1e-9):
        """
        Applies layer normalization to the input tensor.

        Parameters:
        -----------
        output : jnp.ndarray
            The input tensor to be normalized.
        epsilon : float, optional
            A small constant added to the variance to avoid division by zero. Defaults to 1e-9.

        Returns:
        --------
        jnp.ndarray
            The normalized tensor.
        """
        H = output.shape[-1]
        mean = jnp.expand_dims(output.mean(axis=-1), axis=-1)
        std = jnp.expand_dims(output.std(axis=-1), axis=-1)
        output = (output - mean) / (std + epsilon)
        return output
    def __init__(self,name,d_model,d_ff,num_heads,params=None):
        self.d_model = d_model
        self.num_heads = num_heads
        self.key = random.key(210)
        self.params = params
        if params ==None:
            ff1_key,ff2_key,e_mha_key,d_mha_key = random.split(self.key,4)
            self.params = ModuleParams(name,
                                    [MultiHeadAttention.initiate_params('d_mha',num_heads,d_model,d_model,d_mha_key),
                                     MultiHeadAttention.initiate_params('e_mha',num_heads,d_model,d_model,e_mha_key),
                                     FeedForward.initiate_params('ff1',d_model,d_ff,ff1_key),
                                     FeedForward.initiate_params('ff2',d_ff,d_model,ff2_key)])
        self.ff1 = FeedForward("ff1",d_model,d_ff,params = self.params.components[2])
        self.ff2 = FeedForward("ff2",d_ff,d_model,params = self.params.components[3])
        self.d_mha = MultiHeadAttention("d_mha",num_heads,d_model,params = self.params.components[0])
        self.e_mha = MultiHeadAttention("e_mha",num_heads,d_model,params = self.params.components[1])
        self.dropout = Dropout(0.2)
    def predict(self,input,encoder_output,mask):
        attentions = self.d_mha(input,input,input,mask,decoder=True)
        attentions = self.dropout(attentions)
        x = DecoderLayer.layer_normalization(input+attentions)
        e_attentions = self.e_mha(x,encoder_output,encoder_output,mask)
        e_attentions = self.dropout(e_attentions)
        x = DecoderLayer.layer_normalization(x+e_attentions)
        ff_ = self.ff2(self.ff1(x))
        ff_ = self.dropout(ff_)
        x = DecoderLayer.layer_normalization(x+ff_)
        return x
    # def batched_predict(self,inputs,encoder_output,mask):
    #     predictor = vmap(self.predict,in_axes=(0,0,0))
    #     return predictor(inputs,encoder_output,mask)
    def __call__(self,inputs,encoder_output,mask):
        # if len(inputs.shape)>1:
        #     return self.batched_predict(inputs,encoder_output)
        return self.predict(inputs,encoder_output,mask)
    def tree_flatten(self):
        children = (self.params,)
        aux_data = (self.d_model,self.d_ff,self.num_heads)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)
        

In [40]:
class DecoderParams:
    """
    A class to encapsulate the parameters for a decoder in a transformer model, consisting of multiple decoder layers.

    Attributes:
    -----------
    params : ModuleParams
        The parameters of the decoder, including multiple decoder layers.

    Parameters:
    -----------
    name : str
        The name associated with the decoder parameters.
    d_model : int
        The dimensionality of the input and output for each decoder layer.
    d_ff : int
        The dimensionality of the hidden layer in the feedforward network within each decoder layer.
    num_heads : int
        The number of attention heads in the multi-head attention mechanisms of each decoder layer.
    num_layers : int
        The number of decoder layers in the decoder.
    key : jax.random.PRNGKey
        The random key for parameter initialization.
    """

    def __init__(self,name,d_model,d_ff,num_heads,num_layers,key):
        keys = random.split(key,num_layers)
        self.params = ModuleParams(name,
                                   [DecoderLayerParams(f"L{i}",d_model,d_ff,num_heads,key_).params 
                                    for i,key_ in enumerate(keys)])
@register_pytree_node_class
class Decoder:
    """
    A class to represent a decoder in a transformer model, consisting of multiple decoder layers.

    Attributes:
    -----------
    params : ModuleParams
        The parameters of the decoder, including multiple decoder layers.
    keys : list[jax.random.PRNGKey]
        The random keys used for parameter initialization of each decoder layer.
    d_model : int
        The dimensionality of the input and output for each decoder layer.
    d_ff : int
        The dimensionality of the hidden layer in the feedforward network within each decoder layer.
    num_heads : int
        The number of attention heads in the multi-head attention mechanisms of each decoder layer.
    num_layers : int
        The number of decoder layers in the decoder.
    layers : list[DecoderLayer]
        A list of `DecoderLayer` instances representing each layer in the decoder.

    Parameters:
    -----------
    name : str
        The name associated with the decoder.
    d_model : int
        The dimensionality of the input and output for each decoder layer.
    d_ff : int
        The dimensionality of the hidden layer in the feedforward network within each decoder layer.
    num_heads : int
        The number of attention heads in the multi-head attention mechanisms of each decoder layer.
    num_layers : int
        The number of decoder layers in the decoder.
    params : ModuleParams, optional
        The parameters of the decoder. If not provided, they will be initialized.
    """
    def __init__(self,name,d_model,d_ff,num_heads,num_layers,params=None):
        self.params = params
        self.keys = random.split(random.key(251),num_layers)
        self.d_model = d_model
        self.d_ff = d_ff
        self.num_heads = num_heads
        self.num_layers = num_layers
        if params==None:
            self.params = ModuleParams(name,
                                   [DecoderLayerParams(f"L{i}",d_model,d_ff,num_heads,key_).params 
                                    for i,key_ in enumerate(keys)])
        self.layers = [DecoderLayer(f"L{i}",d_model,d_ff,num_heads,self.params.components[i]) for i in range(num_layers)]
    def __call__(self,input,encoder_output,mask):
        x = input
        #print(mask)
        for layer in self.layers:
            x = layer(x,encoder_output,mask)
        return x
    def tree_flatten(self):
        children = (self.params,)
        aux_data = (self.d_model,self.d_ff,self.num_heads,self.num_layers)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)

In [41]:
?LinearLayer.initiate_params

[0;31mSignature:[0m [0mLinearLayer[0m[0;34m.[0m[0minitiate_params[0m[0;34m([0m[0mname[0m[0;34m,[0m [0min_units[0m[0;34m,[0m [0mout_units[0m[0;34m,[0m [0mkey[0m[0;34m,[0m [0mscale[0m[0;34m=[0m[0;36m0.01[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m <no docstring>
[0;31mFile:[0m      /tmp/ipykernel_2923/1082525596.py
[0;31mType:[0m      method

In [42]:
class TransformerParams:
    """
    A class to represent the parameters for a transformer model, including embeddings, encoder, decoder, and output linear layer.

    Attributes:
    -----------
    params : dict
        A dictionary containing the parameters of the transformer model. This includes:
        - Embedding parameters for the input and output tokens
        - Encoder parameters
        - Decoder parameters
        - Linear layer parameters for the output projection

    Parameters:
    -----------
    d_model : int
        The dimensionality of the input and output for the transformer model.
    d_ff : int
        The dimensionality of the hidden layer in the feedforward network within each transformer layer.
    num_heads : int
        The number of attention heads in the multi-head attention mechanisms of the encoder and decoder.
    num_layers : int
        The number of layers in both the encoder and decoder.
    n_vocab : int
        The size of the vocabulary for token embeddings.
    key : jax.random.PRNGKey
        The random key for initializing the parameters.
    """
    def __init__(self,d_model,d_ff,num_heads,num_layers,n_vocab,key):
        self.params = {}
        key_emb,key_emb_dec,key_e,key_d,key_l = random.split(key,5)
        #self.params["Embedding"] = EmbeddingLayer.initiate_params(n_vocab,d_model,key_emb)
        self.params   = ModuleParams("Transformer",
                                     [EmbeddingLayer.initiate_params("In_Embedding",n_vocab,d_model,key_emb),
                                     EmbeddingLayer.initiate_params("Out_Embedding",n_vocab,d_model,key_emb_dec),
                                     EncoderParams("Encoder",d_model,d_ff,num_heads,num_layers,key_e).params,
                                     DecoderParams("Decoder",d_model,d_ff,num_heads,num_layers,key_d).params,
                                     LinearLayer.initiate_params("Linear",d_model,n_vocab,key_l)])
        

@register_pytree_node_class
class Transformer:
    """
    A class to represent a transformer model, including embeddings, encoder, decoder, and output linear layer.

    Attributes:
    -----------
    d_model : int
        The dimensionality of the input and output for the transformer model.
    d_ff : int
        The dimensionality of the hidden layer in the feedforward network within each transformer layer.
    num_heads : int
        The number of attention heads in the multi-head attention mechanisms of the encoder and decoder.
    num_layers : int
        The number of layers in both the encoder and decoder.
    n_vocab : int
        The size of the vocabulary for token embeddings.
    logits : bool
        Whether to output logits (if True) or probabilities (if False).
    params : dict, optional
        The parameters of the transformer model. If not provided, they will be initialized.
    key : jax.random.PRNGKey
        The random key used for initializing the parameters.
    in_embedding : EmbeddingLayer
        The embedding layer for input tokens.
    out_embedding : EmbeddingLayer
        The embedding layer for output tokens.
    encoder : Encoder
        The encoder component of the transformer model.
    decoder : Decoder
        The decoder component of the transformer model.
    linear : LinearLayer
        The linear layer used for output projection.
    """
    def __init__(self,d_model,d_ff,num_heads,num_layers,n_vocab,logits=False,params=None,seed=0):
        self.d_model = d_model
        self.d_ff=d_ff
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.n_vocab = n_vocab
        self.logits = logits
        self.params = params
        self.key = random.key(seed)
        if params == None:
            self.params = TransformerParams(d_model,d_ff,num_heads,num_layers,n_vocab,self.key).params
        #self.embedding = EmbeddingLayer(n_vocab,d_model,self.params["Embedding"])
        self.in_embedding = EmbeddingLayer("In_Embedding",n_vocab,d_model,self.params.components[0])
        self.out_embedding = EmbeddingLayer("Out_Embedding",n_vocab,d_model,self.params.components[1])
        self.encoder = Encoder("Encoder",d_model,d_ff,num_heads,num_layers,self.params.components[2])
        self.decoder = Decoder("Decoder",d_model,d_ff,num_heads,num_layers,self.params.components[3])
        self.linear = LinearLayer("Linear",d_model,n_vocab,params=self.params.components[4])
        self.logits = logits
    def update_params(self):
        self.in_embedding = EmbeddingLayer("In_Embedding",self.n_vocab,self.d_model,self.params.components[0])
        self.out_embedding = EmbeddingLayer("Out_Embedding",self.n_vocab,self.d_model,self.params.components[1])
        self.encoder = Encoder("Encoder",self.d_model,self.d_ff,self.num_heads,self.num_layers,self.params.components[2])
        self.decoder = Decoder("Decoder",self.d_model,self.d_ff,self.num_heads,self.num_layers,self.params.components[3])
        self.linear = LinearLayer("Linear",self.d_model,self.n_vocab,params=self.params.components[4])
    def __call__(self,inputs,outputs):
        input_tokens = jnp.array(inputs['token_ids'])
        input_mask = jnp.array(inputs['padding_mask'])
        output_tokens = jnp.array(outputs['token_ids'])
        output_mask = jnp.array(outputs['padding_mask'])
        input_tokens = Padder.left_shift(input_tokens,5)
        input_mask = Padder.left_shift_mask(input_mask)
        embs = self.in_embedding(input_tokens,input_mask)
        #print("In_Embeddings",embs)
        op_embs = self.out_embedding(output_tokens,output_mask)
        #print("Out_Embeddings",op_embs)
        if len(embs.shape)!=3:
            raise "Dimensions of the input must include (Batch,Token Sequence)"
        encoder_output = self.encoder(embs,input_mask)
        #print("Encoder Output:",encoder_output)
        decoder_output = self.decoder(op_embs,encoder_output,output_mask)
        #print("Decoder Output:",decoder_output)
        output = self.linear(decoder_output)
        #print("Linear Output:",output)
        if self.logits:
            return output
        return jax.nn.softmax(output,axis=-1)
    def tree_flatten(self):
        children = (self.params,)
        aux_data = (self.d_model,self.d_ff,self.num_heads,self.num_layers,self.n_vocab,self.logits)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)

In [43]:
# class Optimizer:
#     def __init__(self,lr,lambda_):
#         self.lr = lr
#         self.lambda_ = lambda_
#     def update_coder_params(self,t,params,grads):
#         for layer in params:
#             params[layer] = self.update_layer_params(t,
#                 params[layer],
#                 grads[layer]
#             )
#         return params
#     def update_layer_params(self,t,params,grads):
#         for type in params:
#             if 'ff' in type:
#                 params[type] = self.update_basic_params(t,
#                     params[type],
#                     grads[type]
#                 )
#             if 'mha' in type:
#                 params[type] = self.update_mha_params(t,
#                     params[type],
#                     grads[type]
#                 )
#         return params
#     def update_mha_params(self,t,params,grads):
#         params['Wo'] = self.update_params(t,params['Wo'],grads['Wo'])
#         for head in params:
#             if head == 'Wo':
#                 continue
#             for type in params[head]:
#                 #print(head,type)
#                 params[head][type] = self.update_params(t,params[head][type],grads[head][type])
#         return params
#     def update_basic_params(self,t,params,grads):
#         for type in params:
#             params[type]= self.update_params(t,params[type],grads[type])
#         return params
#     def update_params(self,t,params,grads):
#         params = params - self.lr*(grads)
#         return params

In [44]:
class Optimizer:
    def __init__(self,lr,lambda_):
        self.lr = lr
        self.lambda_ = lambda_
    def update_params(self,t,params,grads):
        params = params - self.lr*grads
        return params

In [45]:
class AdamW(Optimizer):
    """
    Implementation of the AdamW optimizer using JAX.

    AdamW is an optimization algorithm that combines the benefits of Adam (adaptive moment estimation) with weight decay regularization.

    Attributes:
    -----------
    beta1 : float
        The exponential decay rate for the first moment estimates.
    beta2 : float
        The exponential decay rate for the second moment estimates.
    epsilon : float
        A small constant to prevent division by zero in the update step.
    m : array-like or None
        The first moment vector, initialized to None until the first update.
    v : array-like or None
        The second moment vector, initialized to None until the first update.

    Parameters:
    -----------
    lr : float
        The learning rate for the optimizer.
    beta1 : float
        The exponential decay rate for the first moment estimates.
    beta2 : float
        The exponential decay rate for the second moment estimates.
    epsilon : float
        A small constant to prevent division by zero.
    lambda_ : float
        The weight decay factor for regularization.

    Methods:
    --------
    SetScheduleMultiplier(t):
        Returns a multiplier for the learning rate schedule at timestep `t`.
    update_params(t, params, grads):
        Updates the model parameters using the AdamW optimization algorithm.
    """
    def __init__(self,lr,beta1,beta2,epsilon,lambda_):
        super().__init__(lr,lambda_)
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.m = None
        self.v = None
    def SetScheduleMultiplier(self,t):
        return 0.0001
    def update_params(self,t,params,grads):
        g = grads + self.lambda_*params
        if self.m is None:
            self.m = params*0
            self.v = params*0

        self.m = self.beta1*self.m + (1-self.beta1)*g
        self.v = self.beta2*self.v + (1-self.beta2)*(g**2)
        m_hat = self.m/(1-(self.beta1)**t)
        v_hat = self.v/(1-(self.beta2)**t)
        eta = self.SetScheduleMultiplier(t)
        params = params - eta*((self.lr*m_hat)/((v_hat)**0.5+self.epsilon) + self.lambda_*params)
        return params

In [48]:
class Trainer:
    """
    Trainer class for managing the training process of a model.

    Args:
        model: The model to be trained, which should have an `update_params` method
               to update its parameters and a `params` attribute.
        loss: A function that computes the loss given the model, input data, and target values.
        optimizer: An optimizer instance with an `update_params` method to update the model's
                   parameters based on gradients.
        scheduler: (Optional) A scheduler instance or function to adjust the learning rate or
                   other hyperparameters during training.

    Methods:
        update(t, grads):
            Updates the model parameters using the optimizer and the provided gradients.
        
        train(x, y, epochs, batch_size=None):
            Trains the model for a specified number of epochs.
            
            Args:
                x: A tuple containing the input data for the model. The first element is a dictionary
                   with 'token_ids' and 'padding_mask' for the source input, and the second element
                   is a dictionary for the target input with similar keys.
                y: The target values corresponding to the input data.
                epochs: The number of epochs to train the model.
                batch_size: (Optional) The size of each batch. If not provided, the entire dataset is used
                            as a single batch.
            
            Returns:
                None
    """
    def __init__(self,model,loss,optimizer,schedular=None):
        self.model = model
        self.loss = loss
        #self.lr = lr
        self.optimizer = optimizer
        self.schedular = None
        if schedular!=None:
            self.schedular = schedular(optimizer)
    # def update(self,t,grads):
    #     self.model.params["Encoder"] = self.optimizer.update_coder_params(t,self.model.params["Encoder"],grads["Encoder"])
    #     self.model.params["Decoder"] = self.optimizer.update_coder_params(t,self.model.params["Decoder"],grads["Decoder"])
    #     self.model.params["Linear"] = self.optimizer.update_basic_params(t,self.model.params["Linear"],grads["Linear"])
    #     #self.model.params["Embedding"] = self.optimizer.update_basic_params(t,self.model.params["Embedding"],grads["Embedding"])
    #     self.model.params["In_Embedding"] = self.optimizer.update_basic_params(t,self.model.params["In_Embedding"],grads["In_Embedding"])
    #     self.model.params["Out_Embedding"] = self.optimizer.update_basic_params(t,self.model.params["Out_Embedding"],grads["Out_Embedding"])
    def update(self,t,grads):
        self.model.params = self.optimizer.update_params(t,self.model.params,grads)
        self.model.update_params()
    def train(self,x,y,epochs,batch_size=None):
        data_size = len(x[0]['token_ids'])
        if batch_size== None:
            batch_size = data_size
        # if learning_rate!=None:
        #     self.lr = learning_rate
        t = 0
        for epoch in range(epochs):
            curr = 0
            print("Epoch ",epoch,"\r",flush=True)
            print("Loss:")
            while curr<data_size:
                t+=1
                batch_x = (
                    {'token_ids':x[0]['token_ids'][curr:curr+batch_size],
                     'padding_mask':x[0]['padding_mask'][curr:curr+batch_size]},
                    {'token_ids':x[1]['token_ids'][curr:curr+batch_size],
                     'padding_mask':x[1]['padding_mask'][curr:curr+batch_size]}
                          )
                batch_y = y[curr:curr+batch_size]
                #print(Padder.left_shift(jnp.array(batch_y),5))
                #break
                batch_y = one_hot(Padder.left_shift(jnp.array(batch_y),5),self.model.n_vocab)
                curr = curr+batch_size
                grads = grad(self.loss)(self.model,batch_x,batch_y).params
                #grads = clip_gradients(grads)
                self.update(t,grads)
                if self.schedular!=None:
                    self.schedular.update(512,t,4000)
                    self.optimizer = self.schedular.optimizer
                    #print("Learning Rate:",self.optimizer.lr)
                print(self.loss(self.model,batch_x,batch_y),end="\r",flush=True)
            print(self.loss(self.model,batch_x,batch_y))

In [50]:
def CategoricalCrossEntropy(transformer, x, y):
    """
    Computes the categorical cross-entropy loss for a transformer model.

    Args:
        transformer: A callable model that takes input tokens and output tokens
                     and produces predicted probabilities for each token in the
                     vocabulary.
        x: A tuple containing the input data for the model:
           - The first element is a dictionary with 'token_ids' and 'padding_mask'
             for the source input tokens.
           - The second element is a dictionary with 'token_ids' and 'padding_mask'
             for the target input tokens.
        y: A 3D array of shape (batch_size, sequence_length, n_vocab) representing
           the one-hot encoded target token distributions.

    Returns:
        A scalar value representing the mean categorical cross-entropy loss over
        the batch.
    
    Description:
        This function computes the categorical cross-entropy loss between the predicted
        token probabilities and the true target token distributions. It calculates the
        loss only for tokens where the label is not equal to 5 (which is assumed to be
        a padding token or ignored token). The loss is averaged over the non-masked
        tokens and batch dimensions.

    Example:
        >>> loss = CategoricalCrossEntropy(model, (input_x, target_x), target_y)
    """
    input_tokens = x[0]
    output_tokens = x[1]
    y_hat = transformer(input_tokens, output_tokens)
    labels = jnp.argmax(y, axis=-1)
    mask = labels != 5
    return jnp.mean(
        -(((y * jnp.log(y_hat)).sum(axis=-1)) * mask).sum(axis=-1, keepdims=True) /
        mask.sum(axis=-1, keepdims=True)
    )

In [139]:
def one_hot(x,max):
        return jnp.array(x[:,:,None]==jnp.arange(max),dtype=jnp.float32)


In [51]:
import sentencepiece as spm
import re
import unicodedata
class Tokenizer:
    """
    A class for tokenization and detokenization using SentencePiece.

    This class supports loading a pre-trained SentencePiece model, cleaning text, 
    training a new model, and converting text to tokens and back. 

    Attributes:
        model_prefix (str): Prefix used for the SentencePiece model file.
        model_file (str): Full path to the SentencePiece model file.
        sp (spm.SentencePieceProcessor): Instance of SentencePieceProcessor loaded 
            with the model file.

    Args:
        model_prefix (str): The prefix for the SentencePiece model file. The model file 
                            should be named `<model_prefix>.model`.

    Methods:
        clean_text(sentence: str) -> str:
            Cleans and normalizes the input text by handling contractions, punctuation,
            and whitespace.

        batched_clean_text(x: list[str]) -> list[str]:
            Applies `clean_text` to a list of sentences.

        train(file_name: str, vocab_size: int):
            Trains a SentencePiece model with the given file and vocabulary size, and
            saves it with the specified model prefix.

        __call__(x: str or list[str], out_type=None) -> list[int]:
            Encodes the input text or list of texts into tokens using the trained SentencePiece model.

        detokenize(tokens: list[int]) -> str:
            Converts a list of token IDs back into a human-readable string.

    Raises:
        FileNotFoundError: If the model file does not exist or cannot be loaded.

    Example:
        >>> tokenizer = Tokenizer("my_model_prefix")
        >>> tokenizer.train("data.txt", vocab_size=5000)
        >>> tokens = tokenizer("Hello, world!")
        >>> text = tokenizer.detokenize(tokens)
    """
    import sentencepiece as spm
    def __init__(self,model_prefix):
        self.model_prefix = model_prefix
        self.model_file = self.model_prefix + ".model"
        try:
            self.sp = spm.SentencePieceProcessor(model_file = self.model_file)
        except:
            print("Model File Not Found. Tokenizer must be trained in order to make changes.")
    @classmethod
    def clean_text(cls,sentence):
        pattern = r'[\s]+'
        sentence = re.sub(r"\s+"," ",re.sub(r"([^\'\w])",r" \1",sentence))
        contractions = {"'ve":" have",
                    "'ll":" will",
                    "'m":" am",
                    "'re":" are",
                    "n't":" not",
                    "'d":" had"}
        sentence = unicodedata.normalize("NFD",sentence)
        words = re.split(pattern,sentence)
        for contraction in contractions:
            words = [word.replace(contraction,contractions[contraction]) if contraction in word else word for word in words]
        words = re.split(pattern," ".join(words))
        return " ".join(words)
    @classmethod
    def batched_clean_text(cls,x):
        return [text for text in map(cls.clean_text,x)]
        
    def train(self,file_name,vocab_size):
        spm.SentencePieceTrainer.train(input=file_name,
                                       model_prefix =self.model_prefix,
                                       vocab_size = vocab_size,
                                       control_symbols='<start>,<end>,<pad>')
        self.sp = spm.SentencePieceProcessor(model_file = self.model_file)
    def __call__(self,x,out_type=None):
        if type(x) == str:
            x = Tokenizer.clean_text(x)   
        else:
            x = Tokenizer.batched_clean_text(x)
        #print(x)
        return self.sp.encode(x,out_type)
    def detokenize(self,tokens):
        return self.sp.decode(tokens)

In [52]:
class Padder:
    """
    A class for padding and shifting token sequences for natural language processing tasks.

    This class is designed to pad token sequences to a specified maximum length and to
    perform left shifting operations required for various sequence processing tasks.

    Attributes:
        sp (spm.SentencePieceProcessor): Instance of SentencePieceProcessor used for token operations.
        max_len (int): The maximum length to which sequences will be padded.
        pad_token (int): ID of the padding token.
        start_token (int): ID of the start token.
        end_token (int): ID of the end token.

    Args:
        tokenizer (Tokenizer): An instance of the Tokenizer class to retrieve token IDs.
        max_len (int): Maximum length of sequences after padding.

    Methods:
        add_pads(tokens: list[int], max_len: int = None) -> tuple[list[int], list[int]]:
            Pads the input token sequence to a specified maximum length with special tokens.

        left_shift(tokens: jnp.ndarray, pad_token: int) -> jnp.ndarray:
            Shifts the token sequence to the left by one position, filling the last position
            with the given padding token.

        left_shift_mask(padding: jnp.ndarray) -> jnp.ndarray:
            Shifts the padding mask to the left by one position, filling the last position
            with zeros.

        __call__(tokens: list[int] or list[list[int]]) -> dict:
            Applies padding to a single sequence or a batch of sequences. Returns a dictionary
            with token IDs and padding masks.

    Example:
        >>> tokenizer = Tokenizer("my_model_prefix")
        >>> padder = Padder(tokenizer, max_len=50)
        >>> padded_tokens, pad_mask = padder.add_pads([1, 2, 3, 4])
        >>> shifted_tokens = padder.left_shift(jnp.array([[1, 2, 3, 4]]), pad_token=0)
        >>> padded_batch = padder([[1, 2, 3], [4, 5]])
        >>> print(padded_batch["token_ids"])  # List of padded sequences
        >>> print(padded_batch["padding_mask"])  # List of padding masks
    """
    def __init__(self,tokenizer,max_len):
        self.sp = tokenizer.sp
        self.max_len = max_len
        self.pad_token = self.sp.piece_to_id("<pad>")
        self.start_token = self.sp.piece_to_id("<start>")
        self.end_token = self.sp.piece_to_id("<end>")
    def add_pads(self,tokens,max_len=None):
        if max_len==None:
            max_len = self.max_len
        pad_mask = [1]*(len(tokens)+2)
        if len(tokens)+2>=max_len:
            return [self.start_token]+tokens[:max_len-2]+[self.end_token],pad_mask[:max_len]
        pads_ = [self.pad_token]*(max_len-2-len(tokens))
        pad_mask[max_len:] = [0]*len(pads_)
        return [self.start_token]+tokens+[self.end_token]+pads_,pad_mask
    @classmethod
    def left_shift(cls,tokens,pad_token):
        return jnp.concat([tokens[:,1:],jnp.expand_dims(jnp.repeat(jnp.array([pad_token]),tokens.shape[0]),axis=-1)],axis=-1)
    @classmethod
    def left_shift_mask(cls,padding):
        return jnp.concat([padding[:,1:],jnp.expand_dims(jnp.repeat(jnp.array([0]),padding.shape[0]),axis=-1)],axis=-1)
    def __call__(self,tokens):
        if type(tokens[0])==int:
            return self.add_pads(tokens)
        else:
            pad_map = list(map(self.add_pads,tokens))
            return {"token_ids":[sentence for sentence,_ in pad_map],
                    "padding_mask":[pad_mask for _,pad_mask in pad_map]}
        
        
                
        

In [119]:
import pandas as pd


In [120]:
data = pd.read_csv("language_translation_data/eng_-french.csv")

In [122]:
en_text = data["English words/sentences"].to_list()

In [126]:
fr_text = data["French words/sentences"].to_list()

In [129]:
en_tokenizer = Tokenizer("en_token")

In [131]:
fr_tokenizer = Tokenizer("fr_token")

In [133]:
padding_en = Padder(en_tokenizer,max_len=64)
padding_fr = Padder(fr_tokenizer,max_len=64)
en_tokens = en_tokenizer(en_text)
fr_tokens = fr_tokenizer(fr_text)

In [134]:
padded_text_en = padding_en(en_tokens)

In [135]:
padded_text_fr = padding_fr(fr_tokens)

In [281]:
class Schedular:
    """
    A class for learning rate scheduling in optimization.

    This class adjusts the learning rate of an optimizer according to a specified
    scheduling strategy. It uses a simple learning rate scheduler based on the
    model dimensionality, the current training step, and a warmup step parameter.

    Attributes:
        optimizer (Optimizer): The optimizer instance whose learning rate will be adjusted.

    Args:
        optimizer (Optimizer): An instance of an optimizer whose learning rate needs to be adjusted.

    Methods:
        update(d_model: int, step: int, warmup_steps: int) -> float:
            Updates the learning rate of the optimizer based on the current step and warmup steps.

    Example:
        >>> optimizer = AdamW(lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, lambda_=0.01)
        >>> schedular = Schedular(optimizer)
        >>> lr = schedular.update(d_model=512, step=1000, warmup_steps=4000)
        >>> print(f"Updated Learning Rate: {lr}")
    """
    def __init__(self,optimizer):
        self.optimizer = optimizer
    def update(self,d_model,step,warmup_steps):
        self.optimizer.lr = (d_model**-0.5)*min(step**-0.5,step*warmup_steps**-1.5)
        return self.optimizer.lr

In [None]:
Trans = Transformer(512,2048,8,1,1000)
en_tokens,en_mask = padded_text_en['token_ids'][:1024],padded_text_en['padding_mask'][:1024]
fr_tokens,fr_mask = padded_text_fr['token_ids'][:1024],padded_text_fr['padding_mask'][:1024]
opt = AdamW(1e-2,9e-1,999e-3,1e-9,1e-2)
trainer = Trainer(Trans,CategoricalCrossEntropy,opt)
trainer.train(({'token_ids':en_tokens,'padding_mask':en_mask},
                {'token_ids':fr_tokens,'padding_mask':fr_mask}),
               fr_tokens,500,batch_size=32)

Epoch  0 
Loss:
6.8926473
Epoch  1 
Loss:
6.8774495
Epoch  2 
Loss:
6.8621492
Epoch  3 
Loss:
6.8467507
Epoch  4 
Loss:
6.8311764
Epoch  5 
Loss:
6.8154383
Epoch  6 
Loss:
6.7997185
Epoch  7 
Loss:
6.7837515
Epoch  8 
Loss:
6.7678725
Epoch  9 
Loss:
6.7518153
Epoch  10 
Loss:
6.7358557
Epoch  11 
Loss:
6.7200174
Epoch  12 
Loss:
6.7041483
Epoch  13 
Loss:
6.6878127
Epoch  14 
Loss:
6.6718594
Epoch  15 
Loss:
6.6560326
Epoch  16 
Loss:
6.6401478
Epoch  17 
Loss:
6.6241474
Epoch  18 
Loss:
6.6085376
Epoch  19 
Loss:
6.5932787
Epoch  20 
Loss:
6.5780363
Epoch  21 
Loss:
6.5614834
Epoch  22 
Loss:
6.5458536
Epoch  23 
Loss:
6.5306344
Epoch  24 
Loss:
6.5144577
Epoch  25 
Loss:
6.5001183
Epoch  26 
Loss:
6.4854458
Epoch  27 
Loss:
6.4694576
Epoch  28 
Loss:
6.4542513
Epoch  29 
Loss:
6.4409345
Epoch  30 
Loss:
6.4236565
Epoch  31 
Loss:
6.4095564
Epoch  32 
Loss:
6.3970894
Epoch  33 
Loss:
6.3811674
Epoch  34 
Loss:
6.3659296
Epoch  35 
Loss:
6.3526234
Epoch  36 
Loss:
6.3392717
Epoch  37 


In [53]:
class Translator:
    """
    A class for translating text from one language to another using a neural translation model.

    This class takes in a translation model and tokenizers for both source and target languages,
    along with padding utilities. It performs the translation by tokenizing the input text, 
    padding the tokens, generating predictions from the model, and detokenizing the predicted tokens.

    Attributes:
        model (Transformer): The translation model used for generating predictions.
        en_tokenizer (Tokenizer): Tokenizer for the source language (e.g., English).
        fr_tokenizer (Tokenizer): Tokenizer for the target language (e.g., French).
        en_padder (Padder): Padding utility for the source language tokens.
        fr_padder (Padder): Padding utility for the target language tokens.

    Args:
        model (Transformer): A model instance used for translation.
        en_tokenizer (Tokenizer): Tokenizer instance for the source language.
        fr_tokenizer (Tokenizer): Tokenizer instance for the target language.
        en_padder (Padder): Padding utility for the source language.
        fr_padder (Padder): Padding utility for the target language.

    Methods:
        __call__(text: str) -> None:
            Translates the given text from the source language to the target language and prints the result.

    Example:
        >>> en_tokenizer = Tokenizer("en_model")
        >>> fr_tokenizer = Tokenizer("fr_model")
        >>> en_padder = Padder(en_tokenizer, max_len=128)
        >>> fr_padder = Padder(fr_tokenizer, max_len=128)
        >>> model = Transformer(d_model=512, d_ff=2048, num_heads=8, num_layers=6, n_vocab=30000)
        >>> translator = Translator(model, en_tokenizer, fr_tokenizer, en_padder, fr_padder)
        >>> translator("Hello, how are you?")
        Translates "Hello, how are you?" to the target language and prints the result.
    """
    def __init__(self,model,en_tokenizer,fr_tokenizer,en_padder,fr_padder):
        self.model = model
        self.en_tokenizer = en_tokenizer
        self.fr_tokenizer = fr_tokenizer
        self.en_padder = en_padder
        self.fr_padder = fr_padder
    def __call__(self,text):
        en_tokens = self.en_tokenizer([text])
        #print(en_tokens)
        padded_text_en = self.en_padder(en_tokens)
        print(padded_text_en)
        #fr_tokens = [[3]]
        count = 0
        pred_token = 0
        fr_text = "A"
        while pred_token!=4:
            fr_tokens = self.fr_tokenizer([fr_text])
            padded_text_fr = self.fr_padder(fr_tokens)
            #print(padded_text_fr['token_ids'])
            index=padded_text_fr['token_ids'][0].index(4)
            padded_text_fr['token_ids'][0][index] = 5
            padded_text_fr['padding_mask'][0][index] = 0
            print(padded_text_fr)
            predicted_tokens = jnp.argmax(self.model({'token_ids':padded_text_en['token_ids'],'padding_mask':padded_text_en['padding_mask']},
                   {'token_ids':padded_text_fr['token_ids'],'padding_mask':padded_text_fr['padding_mask']}),axis=-1)
            #print(predicted_tokens)
            pred_token = predicted_tokens[0][count]
            fr_text+=self.fr_tokenizer.detokenize(pred_token.tolist())
            count+=1
        #print(predicted_tokens)
        print([self.fr_tokenizer.detokenize(tokens.tolist()) for tokens in predicted_tokens])
        print(fr_text)

In [73]:
import jaxlib
def convert_weights(params):
    params = copy.deepcopy(params)
    for key in params:
        if type(params[key])==jaxlib.xla_extension.ArrayImpl:
            params[key]=params[key].tolist()
        else:
            params[key] = convert_weights(params[key])
    return params
            

In [74]:
model_weights = convert_weights(Trans.params)

In [75]:
weight_string = json.dumps(model_weights)

In [76]:
with open("weights-6-512.json","w") as file:
    file.write(weight_string)

In [552]:
def convert_weights_jax(params):
    params = copy.deepcopy(params)
    for key in params:
        if type(params[key])==list:
            params[key]=jnp.array(params[key])
        else:
            params[key] = convert_weights_jax(params[key])
    return params

In [553]:
with open("weights.json","r") as file:
    weight_string = file.read()
    params = json.loads(weight_string)
    params = convert_weights_jax(params)

In [589]:
#params

In [302]:
# def clip_gradients(params):
#     params = copy.deepcopy(params)
#     for key in params:
#         if type(params[key])==jaxlib.xla_extension.ArrayImpl:
#             params[key] = jnp.clip(params[key],-1.0,1.0)
#             params[key] = jnp.nan_to_num(params[key])
#         else:
#             params[key] = clip_gradients(params[key])
#     return params