In [2]:
from jax import numpy as jnp
import jax
from jax import grad,vmap
from jax import random
import matplotlib.pyplot as plt
from jax.tree_util import register_pytree_node_class
import numpy as np
from jax import lax as jlax
from jax.tree_util import register_pytree_node_class
import json
#import copyself.components
import jaxlib

In [290]:
@register_pytree_node_class
class Parameter:
    def __init__(self,name,value):
        self.name = name
        self.value = value
        self.shape = value.shape
    def __sub__(self,param):
        if isinstance(param,Parameter):
            return Parameter(self.name,self.value-param.value)
        raise TypeError(f"unsupported operand type(s) for -: {type(param)} and 'Parameter'")
    def __add__(self,other):
        if isinstance(other,Parameter):
            return Parameter(self.name,self.value+other.value)
        if isinstance(other,float):
            return Parameter(self.name,self.value+other)
        raise TypeError(f"unsupported operand type(s) for +: {type(other)} and 'Parameter'")
    def __mul__(self,factor):
        if isinstance(factor,float):
            return Parameter(self.name,factor*self.value)
        raise TypeError(f'Cannot multiply a Parameter with {type(factor)}')
    def __rmul__(self,factor):
        if isinstance(factor,float):
            return Parameter(self.name,factor*self.value)
        raise TypeError(f'Cannot multiply a Parameter with {type(factor)}')
    def __pow__(self,factor):
        return Parameter(self.name,self.value**factor)
    def __truediv__(self,other):
        if isinstance(other,Parameter):
            return Parameter(self.name,self.value/other.value)
        if isinstance(other,float):
            return Parameter(self.name,self.value/other)
        raise TypeError(f'Cannot divide a Parameter with {type(factor)}')
    def tree_flatten(self):
        children = (self.value,)
        aux_data = (self.name,)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)

In [291]:
@register_pytree_node_class
class LinearParams:
    def __init__(self,name,weights):
        self.name = name
        if isinstance(weights,Parameter):
            self.weights = weights
        else:
            self.weights = Parameter("W",weights)
    def __sub__(self,other):
        if isinstance(other,LinearParams):
            return LinearParams(self.name,self.weights-other.weights)
        raise TypeError(f"unsupported operand type(s) for -: {type(other)} and 'LinearParams'")
    def __add__(self,other):
        if isinstance(other,LinearParams) :
            return LinearParams(self.name,self.weights+other.weights)
        if isinstance(other,float) :
            return LinearParams(self.name,self.weights+other)
        raise TypeError(f"unsupported operand type(s) for +: {type(other)} and 'LinearParams'")
    def __mul__(self,other):
        if isinstance(other,float):
            return LinearParams(self.name,self.weights*other)
        raise TypeError(f"Cannot multiply a 'LinearParams' with {type(other)}")
    def __rmul__(self,other):
        if isinstance(other,float):
            return LinearParams(self.name,self.weights*other)
        raise TypeError(f"Cannot multiply a 'LinearParams' with {type(other)}")
    def __truediv__(self,other):
        if isinstance(other,LinearParams) :
            return LinearParams(self.name,self.weights/other.weights)
        if isinstance(other,float):
            return LinearParams(self.name,self.weights/other)
        raise TypeError(f"Cannot divide a 'LinearParams' with {type(other)}")
    def __pow__(self,factor):
        return LinearParams(self.name,self.weights**factor)
    def tree_flatten(self):
        children = (self.weights,)
        aux_data = (self.name,)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)

In [292]:
@register_pytree_node_class
class FeedForwardParams:
    def __init__(self,name,weights,bias):
        self.name = name
        if isinstance(weights,Parameter):
            self.weights = weights
        else:
            self.weights = Parameter("W",weights)
        if isinstance(bias,Parameter):
            self.bias = bias
        else:
            self.bias = Parameter("bais",bias)
    def __sub__(self,other):
        if isinstance(other,FeedForwardParams) :
            return FeedForwardParams(self.name,self.weights-other.weights,self.bias-other.bias)
        raise TypeError(f"unsupported operand type(s) for -: {type(other)} and 'FeedForwardParams'")
    def __add__(self,other):
        if isinstance(other,FeedForwardParams):
            return FeedForwardParams(self.name,self.weights+other.weights,self.bias+other.bias)
        if isinstance(other,float):
            return FeedForwardParams(self.name,self.weights+other,self.bias+other)
        raise TypeError(f"unsupported operand type(s) for +: {type(other)} and 'FeedForwardParams'")
    def __mul__(self,other):
        if isinstance(other,float):
            return FeedForwardParams(self.name,self.weights*other,self.bias*other)
        raise TypeError(f"Cannot multiply a 'FeedForwardParams' with {type(other)}")
    def __rmul__(self,other):
        if isinstance(other,float):
            return FeedForwardParams(self.name,self.weights*other,self.bias*other)
        raise TypeError(f"Cannot multiply a 'FeedForwardParams' with {type(other)}")
    def __truediv__(self,other):
        if isinstance(other,FeedForwardParams):
            return FeedForwardParams(self.name,self.weights/other.weights,self.bias/other.bias)
        if isinstance(other,float):
            return FeedForwardParams(self.name,self.weights/other,self.bias/other)
        raise TypeError(f"Cannot divide a 'FeedForwardParams' with {type(other)}")
    def __pow__(self,factor):
        return FeedForwardParams(self.name,self.weights**factor,self.bias**factor)
    def tree_flatten(self):
        children = (self.weights,self.bias,)
        aux_data = (self.name,)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)
    

In [293]:
@register_pytree_node_class
class AttentionParams:
    def __init__(self,name,w_k,w_q,w_v):
        self.name = name
        if isinstance(w_k,Parameter):
            self.w_k = w_k
        else:
            self.w_k = Parameter("w_k",w_k)
        if isinstance(w_q,Parameter):
            self.w_q = w_q
        else:
            self.w_q = Parameter("w_q",w_q)
        if isinstance(w_v,Parameter):
            self.w_v = w_v
        else:
            self.w_v = Parameter("w_v",w_v)
        
    def __sub__(self,other):
        if isinstance(other,AttentionParams):
            return AttentionParams(self.name,self.w_k-other.w_k,self.w_q-other.w_q,self.w_v-other.w_v)
        raise TypeError(f"unsupported operand type(s) for -: {type(other)} and 'AttentionParams'")
    def __add__(self,other):
        if isinstance(other,AttentionParams):
            return AttentionParams(self.name,self.w_k+other.w_k,self.w_q+other.w_q,self.w_v+other.w_v)
        if isinstance(other,float):
            return AttentionParams(self.name,self.w_k+other,self.w_q+other,self.w_v+other)
        raise TypeError(f"unsupported operand type(s) for +: {type(other)} and 'AttentionParams'")
    def __mul__(self,other):
        if isinstance(other,float):
            return AttentionParams(self.name,self.w_k*other,self.w_q*other,self.w_v*other)
        raise TypeError(f"Cannot multiply a 'AttentionParams' with {type(other)}")
    def __truediv__(self,other):
        if isinstance(other,AttentionParams):
            return AttentionParams(self.name,self.w_k/other.w_k,self.w_q/other.w_q,self.w_v/other.w_v)
        if isinstance(other,float):
            return AttentionParams(self.name,self.w_k/other,self.w_q/other,self.w_v/other)
        raise TypeError(f"Cannot divide a 'AttentionParams' with {type(other)}")
    def __rmul__(self,other):
        if isinstance(other,float):
            return AttentionParams(self.name,self.w_k*other,self.w_q*other,self.w_v*other)
        raise TypeError(f"Cannot multiply a 'AttentionParams' with {type(other)}")
    def __pow__(self,factor):
        return AttentionParams(self.name,self.w_k**factor,self.w_q**factor,self.w_v**factor)
    def tree_flatten(self):
        children = (self.w_k,self.w_q,self.w_v,)
        aux_data = (self.name,)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)

In [294]:
@register_pytree_node_class
class MultiHeadAttentionParams:
    def __init__(self,name,weights,heads:list[AttentionParams]):
        self.name = name
        if isinstance(weights,Parameter):
            self.weights = weights
        else:
            self.weights = Parameter("Wo",weights)
        self.heads = heads
        self.num_heads = len(heads)
    def add(self,head1,head2):
        return head1+head2
    def subtract(self,head1,head2):
        return head1-head2
    def multiply(self,val,head):
        return val*head
    def divide(self,head,val):
        return head/val
    def pow(self,val,head):
        return head**val
    def __sub__(self,other):
        if isinstance(other,MultiHeadAttentionParams):
            heads = list(map(self.subtract,self.heads,other.heads))
            return MultiHeadAttentionParams(self.name,self.weights-other.weights,heads)
        raise TypeError(f"unsupported operand type(s) for -: {type(other)} and 'MultiHeadAttentionParams'")
    def __add__(self,other):
        if isinstance(other,MultiHeadAttentionParams) :
            heads = list(map(self.add,self.heads,other.heads))
            return MultiHeadAttentionParams(self.name,self.weights+other.weights,heads)
        if isinstance(other,float):
            heads = list(map(self.add,self.heads,[other]*self.num_heads))
            return MultiHeadAttentionParams(self.name,self.weights+other,heads)
        raise TypeError(f"unsupported operand type(s) for +: {type(other)} and 'MultiHeadAttentionParams'")
    def __mul__(self,other):
        if isinstance(other,float):
            heads = list(map(self.multiply,[other]*self.num_heads,self.heads))
            return MultiHeadAttentionParams(self.name,self.weights*other,heads)
        raise TypeError(f"Cannot multiply a 'MultiHeadAttentionParams' with {type(other)}")
    def __truediv__(self,other):
        if isinstance(other,MultiHeadAttentionParams) :
            heads = list(map(self.divide,self.heads,other.heads))
            return MultiHeadAttentionParams(self.name,self.weights/other.weights,heads)
        if isinstance(other,float):
            heads = list(map(self.divide,self.heads,[other]*self.num_heads))
            return MultiHeadAttentionParams(self.name,self.weights/other,heads)
        raise TypeError(f"Cannot multiply a 'MultiHeadAttentionParams' with {type(other)}")
    def __rmul__(self,other):
        if isinstance(other,float):
            heads = list(map(self.multiply,[other]*self.num_heads,self.heads))
            return MultiHeadAttentionParams(self.name,self.weights*other,heads)
        raise TypeError(f"Cannot multiply a 'MultiHeadAttentionParams' with {type(other)}")
    def __pow__(self,factor):
        heads = list(map(self.pow,[factor]*self.num_heads,self.heads))
        return MultiHeadAttentionParams(self.name,self.weights**factor,heads)
    def tree_flatten(self):
        children = (self.weights,self.heads,)
        aux_data = (self.name,)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)

In [295]:
@register_pytree_node_class
class ModuleParams:
    def __init__(self,name,components):
        self.name = name
        self.components = components
        self.num_comps = len(components)
    def multiply(self,val,comp):
        return val*comp
    def subtract(self,comp1,comp2):
        return comp1-comp2
    def add(self,comp1,comp2):
        return comp1+comp2
    def pow(self,val,comp):
        return comp**val
    def divide(self,comp,val):
        return comp/val
    def __sub__(self,other):
        if isinstance(other,ModuleParams):
            comps = list(map(self.subtract,self.components,other.components))
            return ModuleParams(self.name,comps)
        raise TypeError(f"unsupported operand type(s) for -: {type(other)} and 'ModuleParams'")
    def __add__(self,other):
        if isinstance(other,ModuleParams) :
            comps = list(map(self.add,self.components,other.components))
            return ModuleParams(self.name,comps)
        if isinstance(other,float):
            comps = list(map(self.add,self.components,[other]*self.num_comps))
            return ModuleParams(self.name,comps)
        raise TypeError(f"unsupported operand type(s) for +: {type(other)} and 'ModuleParams'")
    def __mul__(self,other):
        if isinstance(other,float):
            comps = list(map(self.multiply,[other]*self.num_comps,self.components))
            return ModuleParams(self.name,comps)
        raise TypeError(f"Cannot multiply a 'ModuleParams' with {type(other)}")
    def __truediv__(self,other):
        if isinstance(other,ModuleParams) :
            comps = list(map(self.divide,self.components,other.components))
            return ModuleParams(self.name,comps)
        if isinstance(other,float):
            comps = list(map(self.divide,self.components,[other]*self.num_comps))
            return ModuleParams(self.name,comps)
        raise TypeError(f"Cannot divide a 'ModuleParams' with {type(other)}")
    def __rmul__(self,other):
        if isinstance(other,float):
            comps = list(map(self.multiply,[other]*self.num_comps,self.components))
            return ModuleParams(self.name,comps)
        raise TypeError(f"Cannot multiply a 'ModuleParams' with {type(other)}")
    def __pow__(self,factor):
        comps = list(map(self.pow,[factor]*self.num_comps,self.components))
        return ModuleParams(self.name,comps)
    def tree_flatten(self):
        children = (self.components,)
        aux_data = (self.name,)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)

In [276]:
import random as rnd
class Dropout:
    def __init__(self,dropout_p,seed=0):
        self.dropout_p = 0.2
        self.seed = seed
    def predict(self,x):
        _,key_ = random.split(random.key(rnd.randint(0,1000)))
        mask_ = random.bernoulli(key_,1-self.dropout_p,shape=x.shape)
        dropout_out = mask_*x
        scale = 1/(1-self.dropout_p)
        return dropout_out*scale
    def batched_predict(self,x):
        predictor = vmap(self.predict,in_axes=(0))
        return predictor(x)
    def __call__(self,x):
        if len(x.shape)>1:
            return self.batched_predict(x)
        return self.predict(x)
        


In [277]:
@register_pytree_node_class
class LinearLayer:
    @classmethod
    def initiate_params(cls,name,in_units,out_units,key,scale=1e-2):
        w_key,_= random.split(key,2)
        initializer = jax.nn.initializers.he_normal()
        params = {}
        #params["W"] = random.normal(w_key,shape = (n_vocab,embedding_dims),dtype=jnp.float32)*scale
        initializer = jax.nn.initializers.he_normal()
        #params["W"] = initializer(w_key,shape = (n_vocab,embedding_dims),dtype=jnp.float32)*scale
        params = LinearParams(name,initializer(w_key,shape = (in_units,out_units),dtype=jnp.float32)*scale)
        return params
    def __init__(self,name,in_units,out_units,params=None):
        self.in_units = in_units
        self.out_units = out_units
        self.params = params
        self.key = random.key(210)
        if params==None:
            self.params = LinearLayer.initiate_params(name,self.n_vocab,self.embedding_dims,self.key)
    def predict(self,x):
        x = jnp.matmul(x,self.params.weights.value)
        return x
    def batched_predict(self,x):
        predictor = vmap(self.predict,in_axes=[0])
        return predictor(x)
    def __call__(self,x):
        if len(x.shape)>2:
            return self.batched_predict(x)
        return self.predict(x)
        #print(x)
        
    def tree_flatten(self):
        children = (self.params,)
        aux_data = (self.in_units,self.out_units)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)

In [278]:
# class LoraLayer:
#     def __init__(self,
#                 r:int,
#                 lora_alpha: int,
#                 lora_dropout: int,
#                 merge_weights: bool ):
#         self.r = r
#         self.lora_alpha = lora_alpha
#         # Optional dropout
#         if lora_dropout > 0.:
#             self.lora_dropout = Dropout(p=lora_dropout)
#         else:
#             self.lora_dropout = lambda x: x
#         # Mark the weight as unmerged
#         self.merged = False
#         self.merge_weights = merge_weights

In [279]:
@register_pytree_node_class
class EmbeddingLayer:
    @classmethod
    def initiate_params(cls,name,n_vocab,embedding_dims,key,scale=1e-1):
        w_key,_= random.split(key,2)
        initializer = jax.nn.initializers.he_normal()
        params = {}
        #params["W"] = random.normal(w_key,shape = (n_vocab,embedding_dims),dtype=jnp.float32)*scale
        initializer = jax.nn.initializers.he_normal()
        #params["W"] = initializer(w_key,shape = (n_vocab,embedding_dims),dtype=jnp.float32)*scale
        params = LinearParams(name,initializer(w_key,shape = (n_vocab,embedding_dims),dtype=jnp.float32)*scale)
        return params   
    @classmethod
    def positional_enc(cls,emb_dims,seq_len):
        pos = jnp.arange(seq_len)[:, jnp.newaxis]
        pe = jnp.zeros((seq_len,emb_dims))
        div_terms = jnp.exp(jnp.arange(0, emb_dims, 2) * -(jnp.log(10000.0) / emb_dims))
        pe = pe.at[:, 0::2].set(jnp.sin(pos*div_terms))
        pe = pe.at[:, 1::2].set(jnp.cos(pos*div_terms))
        return pe
    def one_hot(self,x,max):
        return jnp.array(x[:,:,None]==jnp.arange(max),dtype=jnp.float32)
    def __init__(self,name,n_vocab,embedding_dims,params=None):
        self.n_vocab = n_vocab
        self.embedding_dims = embedding_dims
        self.params = params
        self.key = random.key(210)
        if params==None:
            self.params = EmbeddingLayer.initiate_params(name,self.n_vocab,self.embedding_dims,self.key)
    def predict(self,x,mask):
        seq_len = x.shape[-1]
        x = self.one_hot(x,self.n_vocab)
        x = jnp.matmul(x,self.params.weights.value)+EmbeddingLayer.positional_enc(self.embedding_dims,seq_len)
        mask = jnp.expand_dims(mask,axis=-1)
        x=x*mask+jnp.ones(shape=mask.shape)*1e-12
        return x
    def batched_predict(self,x,mask):
        predictor = vmap(self.predict,in_axes=[0,0])
        return predictor(x)
    def __call__(self,x,mask):
        if len(x.shape)>2:
            return self.batched_predict(x,mask)
        return self.predict(x,mask)
        #print(x)
        
    def tree_flatten(self):
        children = (self.params,)
        aux_data = (self.n_vocab,self.embedding_dims)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)

In [13]:
emb = EmbeddingLayer(name="Test_Emb",n_vocab=1000,embedding_dims=512)

In [14]:
x = random.randint(random.key(21),shape=(32,128),minval=0,maxval=1000)

In [15]:
mask = jnp.ones(shape=(32,128))

In [16]:
# jnp.matmul(emb.params.weights.value

In [17]:
emb(x,mask)[:,12]

Array([[-5.3115553e-01,  8.4907073e-01, -8.3078086e-01, ...,
         9.9777108e-01,  4.4762911e-03,  9.9656099e-01],
       [-5.3290927e-01,  8.4754169e-01, -8.2782418e-01, ...,
         1.0035292e+00,  1.0806098e-02,  1.0036786e+00],
       [-5.3966665e-01,  8.4262085e-01, -8.4065729e-01, ...,
         9.9260503e-01, -1.3450127e-03,  9.9290007e-01],
       ...,
       [-5.3454167e-01,  8.4251159e-01, -8.3225435e-01, ...,
         9.9481994e-01,  3.6581811e-03,  9.9596053e-01],
       [-5.3558612e-01,  8.3551776e-01, -8.3124691e-01, ...,
         9.9556798e-01, -4.6265158e-03,  1.0012770e+00],
       [-5.3675395e-01,  8.3664453e-01, -8.2822013e-01, ...,
         1.0007498e+00, -1.2522831e-04,  9.9827188e-01]], dtype=float32)

In [18]:
# class LoraEmbedding(EmbeddingLayer,LoraLayer):
#      def __init__(self,
#                 r:int,
#                 lora_alpha: int,
#                 lora_dropout: int,
#                 merge_weights: bool ,
#                 n_vocab:int,
#                 embedding_dims:int,
#                 params = None):
#          self.key = random.key(5)
#          EmbeddingLayer.__init__(n_vocab,embedding_dims,params)
#          LoraLayer.__init__(r,lora_alpha,lora_dropout,merge_weights)
#          self.A =  random.normal(self.key,shape=(r,n_embeddings)
#          self.B = jnp.zeros(shape = (n_vocab,r))
#     def __call__(self,x):
#         x = self.one_hot(x,self.n_vocab)
#         W = jnp.dot(x,self.params)
#         delta_W = jnp.dot(B,A)
        
#         return  + 

In [19]:
# class TestNetwork:
#     def 

In [20]:
def relu(x):
    return jnp.maximum(0,x)

In [21]:
@register_pytree_node_class
class FeedForward:
    def initiate_params(name,input_shape,units,key,scale=1e-4):
        w_key,b_key = random.split(key,2)
        params = {}
        #params["W"] = random.normal(w_key,shape = (input_shape,units),dtype=jnp.float32)*scale
        #params["b"] = random.normal(b_key,shape = (units,))*scale
        
        initializer = jax.nn.initializers.he_normal()
        #params["W"] = initializer(w_key,shape = (input_shape,units),dtype=jnp.float32)*scale
        params = FeedForwardParams(name,
                                   weights = initializer(w_key,shape = (input_shape,units),dtype=jnp.float32)*scale,
                                   bias = random.normal(b_key,shape = (units,))*scale)
        #params["b"] = initializer(b_key,shape = (units,))*scale
        return params
    def __init__(self,name,d_model,units,activation=lambda x:x,params=None):
        self.activation = activation
        self.units = units
        self.key = random.key(210)
        self.d_model = d_model
        if params == None:
            self.params = FeedForward.initiate_params(name,d_model,self.units,self.key)
        else:
            self.params = params
    def predict(self,input):
        return self.activation(jnp.matmul(input,self.params.weights.value)+self.params.bias.value)
    def batched_predict(self,inputs):
        predictor = vmap(self.predict,in_axes = (0))
        return predictor(inputs)
    def __call__(self,input):
        if len(input.shape)>1:
            return self.batched_predict(input)
        return self.predict(input)
    def tree_flatten(self):
        children = (self.params,)
        aux_data = (self.d_model,self.units,self.activation)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)

In [22]:
ff = FeedForward(name="ff1",d_model=512,units=2048,activation=relu)

In [23]:
ff(emb(x,mask))[:,1]

Array([[1.5880790e-04, 3.9633096e-04, 2.1654676e-04, ..., 2.3233908e-05,
        0.0000000e+00, 3.7421738e-05],
       [1.5861183e-04, 3.9658550e-04, 2.1635339e-04, ..., 2.2884618e-05,
        0.0000000e+00, 3.6784462e-05],
       [1.5878624e-04, 3.9760862e-04, 2.1506406e-04, ..., 2.2920343e-05,
        0.0000000e+00, 3.6622201e-05],
       ...,
       [1.5958722e-04, 3.9776997e-04, 2.1453734e-04, ..., 2.2507738e-05,
        0.0000000e+00, 3.7097256e-05],
       [1.5870271e-04, 3.9681490e-04, 2.1639798e-04, ..., 2.2913002e-05,
        0.0000000e+00, 3.8226081e-05],
       [1.5863587e-04, 3.9581870e-04, 2.1496273e-04, ..., 2.3156172e-05,
        0.0000000e+00, 3.7206482e-05]], dtype=float32)

In [24]:
@register_pytree_node_class
class AttentionHead:
    def initiate_params(name,input_shape,units,key,scale=1e-2):
        q_key,k_key,v_key = random.split(key,3)
        
        params = {}
        # params["Wq"] = random.normal(q_key,shape = (input_shape,units),dtype=jnp.float32)*scale
        # params["Wk"] = random.normal(k_key,shape = (input_shape,units),dtype=jnp.float32)*scale
        # params['Wv'] = random.normal(v_key,shape = (input_shape,units),dtype=jnp.float32)*scale
        initializer = jax.nn.initializers.he_normal()
        # params["Wq"] = initializer(q_key,shape = (input_shape,units),dtype=jnp.float32)*scale
        # params["Wk"] = initializer(k_key,shape = (input_shape,units),dtype=jnp.float32)*scale
        # params['Wv'] = initializer(v_key,shape = (input_shape,units),dtype=jnp.float32)*scale
        params = AttentionParams(name=name,
                                 w_q = initializer(q_key,shape = (input_shape,units),dtype=jnp.float32)*scale,
                                 w_k = initializer(k_key,shape = (input_shape,units),dtype=jnp.float32)*scale,
                                 w_v = initializer(v_key,shape = (input_shape,units),dtype=jnp.float32)*scale
                                )
                                 
        return params
    def __init__(self,name,d,d_model,params=None):
        self.d = d
        self.d_model = d_model 
        self.key = random.key(210)
        self.params = params
        if params ==None:
            self.params = AttentionHead.initiate_params(name,self.d_model,self.d,self.key)
    def predict(self,x_q,x_k,x_v,mask,decoder=False):
        query = jnp.matmul(x_q,self.params.w_q.value)
        key = jnp.matmul(x_k,self.params.w_k.value)
        value = jnp.matmul(x_v,self.params.w_v.value)
        #print("Attenion Shapes:",query.shape,key.shape,value.shape)
        attn_scores = jnp.matmul(query,key.T)/jnp.sqrt(self.d)
        if mask != None:
            mask = jnp.expand_dims(mask,axis=0)
            #print(mask*mask.T)
            attn_scores = attn_scores*(mask*mask.T) +(mask*mask.T!=1)*(-1e-20)
        #print(attn_scores)
        #print(attn_scores.shape)
        softmaxed_attn = jax.nn.softmax(attn_scores)
        if mask != None:
            softmaxed_attn = softmaxed_attn*(mask*mask.T) +(mask*mask.T!=1)*(1e-32)
        #softmaxed_attn = jnp.nan_to_num(softmaxed_attn)
        #print(softmaxed_attn)
        if decoder:
            softmaxed_attn = softmaxed_attn*jnp.triu(jnp.ones(attn_scores.shape))
        #print(softmaxed_attn)
        #print("Value Matrix:",value.shape)
        return jnp.matmul(softmaxed_attn,value)
    def batched_predict(self,x_q,x_k,x_v,mask,decoder=False):
        predictor = vmap(self.predict,in_axes = (0,0,0,0,None))
        return predictor(x_q,x_k,x_v,mask,decoder)
    def __call__(self,x_q,x_k,x_v,mask,decoder=False):
        if len(x_q.shape)>1:
            #print(x_q.shape)
            return self.batched_predict(x_q,x_k,x_v,mask,decoder)
        return self.predict(x_q,x_k,x_v,mask,decoder)
    def tree_flatten(self):
        children = (self.params,)
        aux_data = (self.d,self.d_model)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)

In [25]:
@register_pytree_node_class
class MultiHeadAttention:
    def initiate_params(name,num_heads,input_shape,units,key,scale=1e-3):
        o_key,*h_key = random.split(key,num_heads+1)
        #print(o_key,h_key)
        params = {}
        #params['Wo'] = random.normal(o_key,shape = (input_shape,units),dtype=jnp.float32)*scale
        initializer = jax.nn.initializers.he_normal()
        #params['Wo'] = initializer(o_key,shape = (input_shape,units),dtype=jnp.float32)*scale
        params = MultiHeadAttentionParams(name,
                                          weights = initializer(o_key,shape = (input_shape,units),dtype=jnp.float32)*scale,
                                          heads = [AttentionHead.initiate_params(f"H{i}",
                                                                                 input_shape,
                                                                                 input_shape//num_heads,
                                                                                 h_key[i]) 
                                                   for i in range(num_heads)
                                                  ]
                                         )
        return params
    def __init__(self,name,h,d_model,params=None):
        self.h = h
        self.d_model = d_model
        self.key = random.key(210)
        self.d = d_model//h
        self.params = params
        if params ==None:
            self.params = MultiHeadAttention.initiate_params(name,self.h,self.d_model,self.d_model,self.key)
        
        if self.d_model%self.h!=0:
            raise "D_model not divisible by number of heads"
        self.attentionHeads = [AttentionHead(f"H{i}",
                                             self.d,
                                             self.d_model,
                                             self.params.heads[i]) 
                               for i in range(self.h)]
        
    #def predict(self,x_q,x_k,x_v,mask=None,decoder=False):
        #return jnp.matmul(self.params['Wo'],jnp.concat([head.predict(x_q_i,x_k_i,x_v_i,mask,decoder) for head,x_q_i,x_k_i,x_v_i in zip(self.attentionHeads,[x_q]*8,[x_k]*8,[x_v]*8)]))
    def calc_attentions(self,x_q,x_k,x_v,mask=None,decoder=False):
        concat_attn = jnp.concat([head.batched_predict(x_q_i,x_k_i,x_v_i,mask,decoder) 
                                  for head,x_q_i,x_k_i,x_v_i in 
                                  zip(self.attentionHeads,[x_q]*8,[x_k]*8,[x_v]*8)],
                                 axis=-1)
        return concat_attn
    def predict(self,attns):
        return jnp.matmul(attns,self.params.weights.value)
    def batched_predict(self,attns):
        predictor = vmap(self.predict,in_axes = (0))
        return predictor(attns)
    def __call__(self,x_q,x_k,x_v,mask=None,decoder=False):
        attns = self.calc_attentions(x_q,x_k,x_v,mask,decoder)
        if len(x_q.shape)>1: 
            return self.batched_predict(attns)
        return self.predict(attns)
    def tree_flatten(self):
        children = (self.params,)
        aux_data = (self.h,self.d_model)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)

In [26]:
@register_pytree_node_class
class EncoderLayer:
    def layer_normalization(output,epsilon=1e-9):
        H = output.shape[-1]
        mean = jnp.expand_dims(output.mean(axis=-1),axis=-1)
        std = jnp.expand_dims(output.std(axis=-1),axis=-1)
        output = (output - mean)/(std+epsilon)
        return output
    def __init__(self,name,d_model,d_ff,num_heads,params=None):
        self.d_model = d_model
        self.d_ff = d_ff
        self.num_heads = num_heads
        self.params = params
        self.key = random.key(210)
        if params==None:
            ff1_key,ff2_key,mha_key = random.split(self.key,3)
            self.params = ModuleParams(name,
                                       [MultiHeadAttention.initiate_params('mha',num_heads,d_model,d_model,mha_key),
                                        FeedForward.initiate_params('ff1',d_model,d_ff,ff1_key),
                                        FeedForward.initiate_params('ff2',d_ff,d_model,ff2_key)])
        
        self.ff1 = FeedForward("ff1",d_model,d_ff,params=self.params.components[1])
        self.ff2 = FeedForward("ff2",d_ff,d_model,params=self.params.components[2])
        #self.__name__ = f"EncoderLayer{num}"
        self.mha = MultiHeadAttention("mha",num_heads,d_model,self.params.components[0])
        self.dropout = Dropout(0.2)
       
        #print(self.params)
    def predict(self,input,mask):
        attentions = self.mha(input,input,input,mask)
        #print(mask)
        attentions = self.dropout(attentions)
        #print("Attentions")
        #print(attentions)
        x = EncoderLayer.layer_normalization(input+attentions)
        #print("x+attentions")
       # print(x)
        ff_ = self.ff2(self.ff1(x))
        ff_ = self.dropout(ff_)
        #print("x+ff_")
        #print(x+ff_)
        x = EncoderLayer.layer_normalization(x+ff_)
        #print(x)
        return x
    # def batched_predict(self,inputs,mask):
    #     predictor = vmap(self.predict,in_axes=(0,0))
    #     return predictor(inputs,mask)
    def __call__(self,inputs,mask):
        # if len(inputs.shape)>1:
        #     return self.batched_predict(inputs,mask)
        return self.predict(inputs,mask)
    def tree_flatten(self):
        children = (self.params,)
        aux_data = (self.d_model,self.d_ff,self.num_heads)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)
        

In [27]:
#EncoderLayer.layer_normalization(random.normal(random.key(210),shape=(64,128,512)))

In [28]:
class EncoderLayerParams:
    def __init__(self,name,d_model,d_ff,num_heads,key):
        ff1_key,ff2_key,mha_key = random.split(key,3)
        self.params = ModuleParams(name,
                                    [MultiHeadAttention.initiate_params('mha',num_heads,d_model,d_model,mha_key),
                                    FeedForward.initiate_params('ff1',d_model,d_ff,ff1_key),
                                    FeedForward.initiate_params('ff2',d_ff,d_model,ff2_key)])
        

class EncoderParams:
    def __init__(self,name,d_model,d_ff,num_heads,num_layers,key):
        keys = random.split(key,num_layers)
        self.params = ModuleParams(name,
                                   [EncoderLayerParams(f"L{i}",d_model,d_ff,num_heads,key_).params 
                                    for i,key_ in enumerate(keys)])
        
@register_pytree_node_class
class Encoder:
    def __init__(self,name,d_model,d_ff,num_heads,num_layers,params=None):
        #self.num= generate_number(num_layers)
        self.d_model=d_model
        self.d_ff=d_ff
        self.num_heads=num_heads
        self.num_layers=num_layers
        self.key = random.key(210)
        self.params = params
        if params==None:
            self.params = ModuleParams(name,
                                   [EncoderLayerParams(f"L{i}",d_model,d_ff,num_heads,key_).params 
                                    for i,key_ in enumerate(keys)])
        self.layers = [EncoderLayer(f"L{i}",d_model,d_ff,num_heads,self.params.components[i]) for i in range(num_layers)]
    def __call__(self,input,mask):
        x = input
        #print("Encoder Input Shape:",x.shape)
        for layer in self.layers:
            x = layer(x,mask)
        return x
    def tree_flatten(self):
        children = (self.params,)
        aux_data = (self.d_model,self.d_ff,self.num_heads,self.num_layers)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)

In [29]:
class DecoderLayerParams:
    def __init__(self,name,d_model,d_ff,num_heads,key):
        self.params = {}
        ff1_key,ff2_key,e_mha_key,d_mha_key = random.split(key,4)
        self.params = ModuleParams(name,
                                [MultiHeadAttention.initiate_params('d_mha',num_heads,d_model,d_model,d_mha_key),
                                 MultiHeadAttention.initiate_params('e_mha',num_heads,d_model,d_model,e_mha_key),
                                 FeedForward.initiate_params('ff1',d_model,d_ff,ff1_key),
                                 FeedForward.initiate_params('ff2',d_ff,d_model,ff2_key)])

@register_pytree_node_class        
class DecoderLayer:
    def layer_normalization(output,epsilon=1e-9):
        H = output.shape[-1]
        mean = jnp.expand_dims(output.mean(axis=-1),axis=-1)
        std = jnp.expand_dims(output.std(axis=-1),axis=-1)
        output = (output - mean)/(std+epsilon)
        return output
    def __init__(self,name,d_model,d_ff,num_heads,params=None):
        self.d_model = d_model
        self.num_heads = num_heads
        self.key = random.key(210)
        self.params = params
        if params ==None:
            ff1_key,ff2_key,e_mha_key,d_mha_key = random.split(self.key,4)
            self.params = ModuleParams(name,
                                    [MultiHeadAttention.initiate_params('d_mha',num_heads,d_model,d_model,d_mha_key),
                                     MultiHeadAttention.initiate_params('e_mha',num_heads,d_model,d_model,e_mha_key),
                                     FeedForward.initiate_params('ff1',d_model,d_ff,ff1_key),
                                     FeedForward.initiate_params('ff2',d_ff,d_model,ff2_key)])
        self.ff1 = FeedForward("ff1",d_model,d_ff,params = self.params.components[2])
        self.ff2 = FeedForward("ff2",d_ff,d_model,params = self.params.components[3])
        self.d_mha = MultiHeadAttention("d_mha",num_heads,d_model,params = self.params.components[0])
        self.e_mha = MultiHeadAttention("e_mha",num_heads,d_model,params = self.params.components[1])
        self.dropout = Dropout(0.2)
    def predict(self,input,encoder_output,mask):
        attentions = self.d_mha(input,input,input,mask,decoder=True)
        #attentions = self.dropout(attentions)
        x = DecoderLayer.layer_normalization(input+attentions)
        e_attentions = self.e_mha(x,encoder_output,encoder_output,mask)
        #e_attentions = self.dropout(e_attentions)
        x = DecoderLayer.layer_normalization(x+e_attentions)
        ff_ = self.ff2(self.ff1(x))
        ff_ = self.dropout(ff_)
        x = DecoderLayer.layer_normalization(x+ff_)
        return x
    # def batched_predict(self,inputs,encoder_output,mask):
    #     predictor = vmap(self.predict,in_axes=(0,0,0))
    #     return predictor(inputs,encoder_output,mask)
    def __call__(self,inputs,encoder_output,mask):
        # if len(inputs.shape)>1:
        #     return self.batched_predict(inputs,encoder_output)
        return self.predict(inputs,encoder_output,mask)
    def tree_flatten(self):
        children = (self.params,)
        aux_data = (self.d_model,self.d_ff,self.num_heads)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)
        

In [30]:
class DecoderParams:
    def __init__(self,name,d_model,d_ff,num_heads,num_layers,key):
        keys = random.split(key,num_layers)
        self.params = ModuleParams(name,
                                   [DecoderLayerParams(f"L{i}",d_model,d_ff,num_heads,key_).params 
                                    for i,key_ in enumerate(keys)])
@register_pytree_node_class
class Decoder:
    def __init__(self,name,d_model,d_ff,num_heads,num_layers,params=None):
        self.params = params
        self.keys = random.split(random.key(251),num_layers)
        self.d_model = d_model
        self.d_ff = d_ff
        self.num_heads = num_heads
        self.num_layers = num_layers
        if params==None:
            self.params = ModuleParams(name,
                                   [DecoderLayerParams(f"L{i}",d_model,d_ff,num_heads,key_).params 
                                    for i,key_ in enumerate(keys)])
        self.layers = [DecoderLayer(f"L{i}",d_model,d_ff,num_heads,self.params.components[i]) for i in range(num_layers)]
    def __call__(self,input,encoder_output,mask):
        x = input
        #print(mask)
        for layer in self.layers:
            x = layer(x,encoder_output,mask)
        return x
    def tree_flatten(self):
        children = (self.params,)
        aux_data = (self.d_model,self.d_ff,self.num_heads,self.num_layers)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)

In [31]:
?LinearLayer.initiate_params

[0;31mSignature:[0m [0mLinearLayer[0m[0;34m.[0m[0minitiate_params[0m[0;34m([0m[0mname[0m[0;34m,[0m [0min_units[0m[0;34m,[0m [0mout_units[0m[0;34m,[0m [0mkey[0m[0;34m,[0m [0mscale[0m[0;34m=[0m[0;36m0.01[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m <no docstring>
[0;31mFile:[0m      /tmp/ipykernel_1931/2505762081.py
[0;31mType:[0m      method

In [207]:
class TransformerParams:
    def __init__(self,d_model,d_ff,num_heads,num_layers,n_vocab,key):
        self.params = {}
        key_emb,key_emb_dec,key_e,key_d,key_l = random.split(key,5)
        #self.params["Embedding"] = EmbeddingLayer.initiate_params(n_vocab,d_model,key_emb)
        self.params   = ModuleParams("Transformer",
                                     [EmbeddingLayer.initiate_params("In_Embedding",n_vocab,d_model,key_emb),
                                     EmbeddingLayer.initiate_params("Out_Embedding",n_vocab,d_model,key_emb_dec),
                                     EncoderParams("Encoder",d_model,d_ff,num_heads,num_layers,key_e).params,
                                     DecoderParams("Decoder",d_model,d_ff,num_heads,num_layers,key_d).params,
                                     LinearLayer.initiate_params("Linear",d_model,n_vocab,key_l)])
        

@register_pytree_node_class
class Transformer:
    def __init__(self,d_model,d_ff,num_heads,num_layers,n_vocab,logits=False,params=None,seed=0):
        self.d_model = d_model
        self.d_ff=d_ff
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.n_vocab = n_vocab
        self.logits = logits
        self.params = params
        self.key = random.key(seed)
        if params == None:
            self.params = TransformerParams(d_model,d_ff,num_heads,num_layers,n_vocab,self.key).params
        #self.embedding = EmbeddingLayer(n_vocab,d_model,self.params["Embedding"])
        self.in_embedding = EmbeddingLayer("In_Embedding",n_vocab,d_model,self.params.components[0])
        self.out_embedding = EmbeddingLayer("Out_Embedding",n_vocab,d_model,self.params.components[1])
        self.encoder = Encoder("Encoder",d_model,d_ff,num_heads,num_layers,self.params.components[2])
        self.decoder = Decoder("Decoder",d_model,d_ff,num_heads,num_layers,self.params.components[3])
        self.linear = LinearLayer("Linear",d_model,n_vocab,params=self.params.components[4])
        self.logits = logits
    def update_params(self):
        self.in_embedding = EmbeddingLayer("In_Embedding",self.n_vocab,self.d_model,self.params.components[0])
        self.out_embedding = EmbeddingLayer("Out_Embedding",self.n_vocab,self.d_model,self.params.components[1])
        self.encoder = Encoder("Encoder",self.d_model,self.d_ff,self.num_heads,self.num_layers,self.params.components[2])
        self.decoder = Decoder("Decoder",self.d_model,self.d_ff,self.num_heads,self.num_layers,self.params.components[3])
        self.linear = LinearLayer("Linear",self.d_model,self.n_vocab,params=self.params.components[4])
    def __call__(self,inputs,outputs):
        input_tokens = jnp.array(inputs['token_ids'])
        input_mask = jnp.array(inputs['padding_mask'])
        output_tokens = jnp.array(outputs['token_ids'])
        output_mask = jnp.array(outputs['padding_mask'])
        input_tokens = Padder.left_shift(input_tokens,5)
        input_mask = Padder.left_shift_mask(input_mask)
        embs = self.in_embedding(input_tokens,input_mask)
        #print("In_Embeddings",embs)
        op_embs = self.out_embedding(output_tokens,output_mask)
        #print("Out_Embeddings",op_embs)
        if len(embs.shape)!=3:
            raise "Dimensions of the input must include (Batch,Token Sequence)"
        encoder_output = self.encoder(embs,input_mask)
        #print("Encoder Output:",encoder_output)
        decoder_output = self.decoder(op_embs,encoder_output,output_mask)
        #print("Decoder Output:",decoder_output)
        output = self.linear(decoder_output)
        #print("Linear Output:",output)
        if self.logits:
            return output
        return jax.nn.softmax(output,axis=-1)
    def tree_flatten(self):
        children = (self.params,)
        aux_data = (self.d_model,self.d_ff,self.num_heads,self.num_layers,self.n_vocab,self.logits)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data,*children)

In [33]:
# class Optimizer:
#     def __init__(self,lr,lambda_):
#         self.lr = lr
#         self.lambda_ = lambda_
#     def update_coder_params(self,t,params,grads):
#         for layer in params:
#             params[layer] = self.update_layer_params(t,
#                 params[layer],
#                 grads[layer]
#             )
#         return params
#     def update_layer_params(self,t,params,grads):
#         for type in params:
#             if 'ff' in type:
#                 params[type] = self.update_basic_params(t,
#                     params[type],
#                     grads[type]
#                 )
#             if 'mha' in type:
#                 params[type] = self.update_mha_params(t,
#                     params[type],
#                     grads[type]
#                 )
#         return params
#     def update_mha_params(self,t,params,grads):
#         params['Wo'] = self.update_params(t,params['Wo'],grads['Wo'])
#         for head in params:
#             if head == 'Wo':
#                 continue
#             for type in params[head]:
#                 #print(head,type)
#                 params[head][type] = self.update_params(t,params[head][type],grads[head][type])
#         return params
#     def update_basic_params(self,t,params,grads):
#         for type in params:
#             params[type]= self.update_params(t,params[type],grads[type])
#         return params
#     def update_params(self,t,params,grads):
#         params = params - self.lr*(grads)
#         return params

In [231]:
class Optimizer:
    def __init__(self,lr,lambda_):
        self.lr = lr
        self.lambda_ = lambda_
    def update_params(self,t,params,grads):
        params = params - self.lr*grads
        return params

In [280]:
class AdamW(Optimizer):
    '''Implementation of AdamW 
    Optimizer using JAX'''
    def __init__(self,lr,beta1,beta2,epsilon,lambda_):
        super().__init__(lr,lambda_)
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
        self.m = None
        self.v = None
    def SetScheduleMultiplier(self,t):
        return 0.0001
    def update_params(self,t,params,grads):
        g = grads + self.lambda_*params
        if not self.m:
            self.m = (1-self.beta1)*g
            self.v = (1-self.beta2)*(g**2)
        else:
            self.m = self.beta1*self.m + (1-self.beta1)*g
            self.v = self.beta2*self.v + (1-self.beta2)*(g**2)
        m_hat = self.m/(1-(self.beta1)**t)
        v_hat = self.v/(1-(self.beta2)**t)
        eta = self.SetScheduleMultiplier(t)
        params = params - eta*((self.lr*m_hat)/((v_hat)**0.5+self.epsilon) + self.lambda_*params)
        return params

In [210]:
class Trainer:
    def __init__(self,model,loss,optimizer,schedular=None):
        self.model = model
        self.loss = loss
        #self.lr = lr
        self.optimizer = optimizer
        self.schedular = None
        if schedular!=None:
            self.schedular = schedular(optimizer)
    # def update(self,t,grads):
    #     self.model.params["Encoder"] = self.optimizer.update_coder_params(t,self.model.params["Encoder"],grads["Encoder"])
    #     self.model.params["Decoder"] = self.optimizer.update_coder_params(t,self.model.params["Decoder"],grads["Decoder"])
    #     self.model.params["Linear"] = self.optimizer.update_basic_params(t,self.model.params["Linear"],grads["Linear"])
    #     #self.model.params["Embedding"] = self.optimizer.update_basic_params(t,self.model.params["Embedding"],grads["Embedding"])
    #     self.model.params["In_Embedding"] = self.optimizer.update_basic_params(t,self.model.params["In_Embedding"],grads["In_Embedding"])
    #     self.model.params["Out_Embedding"] = self.optimizer.update_basic_params(t,self.model.params["Out_Embedding"],grads["Out_Embedding"])
    def update(self,t,grads):
        self.model.params = self.optimizer.update_params(t,self.model.params,grads)
        self.model.update_params()
    def train(self,x,y,epochs,batch_size=None):
        data_size = len(x[0]['token_ids'])
        if batch_size== None:
            batch_size = data_size
        # if learning_rate!=None:
        #     self.lr = learning_rate
        t = 0
        for epoch in range(epochs):
            curr = 0
            print("Epoch ",epoch,"\r",flush=True)
            print("Loss:")
            while curr<data_size:
                t+=1
                batch_x = (
                    {'token_ids':x[0]['token_ids'][curr:curr+batch_size],
                     'padding_mask':x[0]['padding_mask'][curr:curr+batch_size]},
                    {'token_ids':x[1]['token_ids'][curr:curr+batch_size],
                     'padding_mask':x[1]['padding_mask'][curr:curr+batch_size]}
                          )
                batch_y = y[curr:curr+batch_size]
                #print(Padder.left_shift(jnp.array(batch_y),5))
                #break
                batch_y = one_hot(Padder.left_shift(jnp.array(batch_y),5),self.model.n_vocab)
                curr = curr+batch_size
                grads = grad(self.loss)(self.model,batch_x,batch_y).params
                #grads = clip_gradients(grads)
                self.update(t,grads)
                if self.schedular!=None:
                    self.schedular.update(512,t,4000)
                    self.optimizer = self.schedular.optimizer
                    #print("Learning Rate:",self.optimizer.lr)
                print(self.loss(self.model,batch_x,batch_y),end="\r",flush=True)
            print(self.loss(self.model,batch_x,batch_y))

In [204]:
def CategoricalCrossEntropy(transformer,x,y):
    input_tokens = x[0]
    output_tokens = x[1]
    y_hat = transformer(input_tokens,output_tokens)
    labels = jnp.argmax(y,axis=-1)
    mask = labels!=5
    return jnp.mean(-(((y*jnp.log(y_hat)).sum(axis=-1))*mask).sum(axis=-1,keepdims=True)/mask.sum(axis=-1,keepdims=True))
    

In [91]:
#fr_tokens

In [92]:
# CategoricalCrossEntropy(Trans,({'token_ids':en_tokens[:32],'padding_mask':en_mask[:32]},
#                {'token_ids':fr_tokens[:32],'padding_mask':fr_mask[:32]}),
#               one_hot(Padder.left_shift(jnp.array(fr_tokens[:32]),5),1000))

In [93]:
#Trans({'token_ids':en_tokens[:32],'padding_mask':en_mask[:32]},{'token_ids':fr_tokens[:32],'padding_mask':fr_mask[:32]})

In [94]:

#test_loss = CategoricalCrossEntropy(Trans,
                                    # ({'token_ids':en_tokens[:32],'padding_mask':en_mask[:32]},
                                    # {'token_ids':fr_tokens[:32],'padding_mask':fr_mask[:32]}),
                                    # one_hot(Padder.left_shift(jnp.array(fr_tokens[:32]),5),1000))

In [95]:
#?Transformer

In [96]:
#test_loss

In [139]:
def one_hot(x,max):
        return jnp.array(x[:,:,None]==jnp.arange(max),dtype=jnp.float32)


In [98]:
# CategoricalCrossEntropy(Trans,({'token_ids':en_tokens,'padding_mask':en_mask},{'token_ids':fr_tokens,'padding_mask':fr_mask}),y_hot)

In [99]:
# x = random.randint(random.key(3),minval=0,maxval=3000,shape = (100,128))

In [100]:
# jnp.argmax(Trans(x,x),axis=-1)

In [101]:
# class SubwordTokenizer:
    # def __init__():

In [102]:
# class Tokenizer:
#     def __init__(self,vocab_size,):
        

In [103]:
# import unicodedata
# import re
# sentence = "I'm going to be some person, I couldn't be what I wanted to be but I'll be someone in 2025"
# # 
# re.sub(r"\s+"," ",re.sub(r"([^\'\w])",r" \1",sentence))

In [104]:

# def clean_text(sentence):
#     pattern = r'[\s]+'
#     sentence = re.sub(r"\s+"," ",re.sub(r"([^\'\w])",r" \1",sentence))
#     contractions = {"'ve":" have",
#                     "'ll":" will",
#                     "'m":" am",
#                     "'re":" are",
#                     "n't":" not",
#                     "'d":" had"}
#     # sentence = unicodedata.normalize("NFD",sentence)
#     words = re.split(pattern,sentence)
#     for contraction in contractions:
#         words = [word.replace(contraction,contractions[contraction]) if contraction in word else word for word in words]
#     words = re.split(pattern," ".join(words))
#     return " ".join(words)

In [105]:
# clean_text(sentence)

In [106]:
# sp.decode([1]+sp.encode(clean_text(sentence))+[2])

In [140]:
import sentencepiece as spm
import re
import unicodedata
class Tokenizer:
    import sentencepiece as spm
    def __init__(self,model_prefix):
        self.model_prefix = model_prefix
        self.model_file = self.model_prefix + ".model"
        try:
            self.sp = spm.SentencePieceProcessor(model_file = self.model_file)
        except:
            print("Model File Not Found. Tokenizer must be trained in order to make changes.")
    @classmethod
    def clean_text(cls,sentence):
        pattern = r'[\s]+'
        sentence = re.sub(r"\s+"," ",re.sub(r"([^\'\w])",r" \1",sentence))
        contractions = {"'ve":" have",
                    "'ll":" will",
                    "'m":" am",
                    "'re":" are",
                    "n't":" not",
                    "'d":" had"}
        sentence = unicodedata.normalize("NFD",sentence)
        words = re.split(pattern,sentence)
        for contraction in contractions:
            words = [word.replace(contraction,contractions[contraction]) if contraction in word else word for word in words]
        words = re.split(pattern," ".join(words))
        return " ".join(words)
    @classmethod
    def batched_clean_text(cls,x):
        return [text for text in map(cls.clean_text,x)]
        
    def train(self,file_name,vocab_size):
        spm.SentencePieceTrainer.train(input=file_name,
                                       model_prefix =self.model_prefix,
                                       vocab_size = vocab_size,
                                       control_symbols='<start>,<end>,<pad>')
        self.sp = spm.SentencePieceProcessor(model_file = self.model_file)
    def __call__(self,x,out_type=None):
        if type(x) == str:
            x = Tokenizer.clean_text(x)   
        else:
            x = Tokenizer.batched_clean_text(x)
        #print(x)
        return self.sp.encode(x,out_type)
    def detokenize(self,tokens):
        return self.sp.decode(tokens)

In [141]:
def left_shift(tokens):
    return tokens[:,1:]
def right_shift(tokens):
    return tokens[:,:-2]

In [142]:
class Padder:
    def __init__(self,tokenizer,max_len):
        self.sp = tokenizer.sp
        self.max_len = max_len
        self.pad_token = self.sp.piece_to_id("<pad>")
        self.start_token = self.sp.piece_to_id("<start>")
        self.end_token = self.sp.piece_to_id("<end>")
    def add_pads(self,tokens,max_len=None):
        if max_len==None:
            max_len = self.max_len
        pad_mask = [1]*(len(tokens)+2)
        if len(tokens)+2>=max_len:
            return [self.start_token]+tokens[:max_len-2]+[self.end_token],pad_mask[:max_len]
        pads_ = [self.pad_token]*(max_len-2-len(tokens))
        pad_mask[max_len:] = [0]*len(pads_)
        return [self.start_token]+tokens+[self.end_token]+pads_,pad_mask
    @classmethod
    def left_shift(cls,tokens,pad_token):
        return jnp.concat([tokens[:,1:],jnp.expand_dims(jnp.repeat(jnp.array([pad_token]),tokens.shape[0]),axis=-1)],axis=-1)
    @classmethod
    def left_shift_mask(cls,padding):
        return jnp.concat([padding[:,1:],jnp.expand_dims(jnp.repeat(jnp.array([0]),padding.shape[0]),axis=-1)],axis=-1)
    def __call__(self,tokens):
        if type(tokens[0])==int:
            return self.add_pads(tokens)
        else:
            pad_map = list(map(self.add_pads,tokens))
            return {"token_ids":[sentence for sentence,_ in pad_map],
                    "padding_mask":[pad_mask for _,pad_mask in pad_map]}
        
        
                
        

In [110]:
# tokenizer = Tokenizer("test")

In [111]:
# padding = Padder(tokenizer,32)

In [112]:
# ?spm.SentencePieceTrainer.train

In [113]:
#Trans(jnp.array(padding(tokenizer(sentences))['token_ids']),jnp.array(padding(tokenizer(sentences))['token_ids']))

In [114]:
# tokenizer.sp

In [115]:
# def n_grams(sentence):
#     sentences = []
#     words = [word for word in sentence.split(" ")]
#     for i in range(2,len(words)+1):
#         # sentences.append(" ".join(words[:i]))
#     return sentences

In [116]:
# sentences = n_grams(sentence)

In [117]:
# !kaggle d download devicharith/language-translation-englishfrench

In [118]:
# !unzip language-translation-englishfrench.zip -d "language_translation_data"

In [119]:
import pandas as pd


In [120]:
data = pd.read_csv("language_translation_data/eng_-french.csv")

In [121]:
#data

In [122]:
en_text = data["English words/sentences"].to_list()

In [123]:
# jax.nn.softmax(jnp.array([-jnp.inf,-jnp.inf,-jnp.inf,-jnp.inf]))

In [124]:
# en_text

In [125]:
# with open("en_text.txt","w") as fp:
#     for text in en_text:
#         print(text,file=fp)

In [126]:
fr_text = data["French words/sentences"].to_list()

In [127]:
# with open("fr_text.txt","w") as fp:
#     for text in fr_text:
#         print(text,file=fp)

In [128]:
# ?Tokenizer

In [129]:
en_tokenizer = Tokenizer("en_token")

In [130]:
#en_tokenizer.train("en_text.txt",vocab_size=1000)

In [131]:
fr_tokenizer = Tokenizer("fr_token")

In [132]:
#fr_tokenizer.train("fr_text.txt",vocab_size=1000)

In [133]:
padding_en = Padder(en_tokenizer,max_len=64)
padding_fr = Padder(fr_tokenizer,max_len=64)
en_tokens = en_tokenizer(en_text)
fr_tokens = fr_tokenizer(fr_text)

In [134]:
padded_text_en = padding_en(en_tokens)

In [135]:
padded_text_fr = padding_fr(fr_tokens)

In [136]:
#padded_text_en['token_ids']

In [84]:
#padded_text_en['token_ids']

In [85]:
#embs = Emb(en_tokens[:32])

In [86]:
4e-4

0.0004

In [87]:
#9e-3
step = 4000
warmup_steps = 4000

In [85]:
(512**-0.5)*min(step**-0.5,step*warmup_steps**-1.5)

0.0006987712429686843

In [281]:
class Schedular:
    def __init__(self,optimizer):
        self.optimizer = optimizer
    def update(self,d_model,step,warmup_steps):
        self.optimizer.lr = (d_model**-0.5)*min(step**-0.5,step*warmup_steps**-1.5)
        return self.optimizer.lr

In [298]:
999e-3

0.999

In [None]:
Trans = Transformer(512,2048,8,1,1000)
en_tokens,en_mask = padded_text_en['token_ids'][:1024],padded_text_en['padding_mask'][:1024]
fr_tokens,fr_mask = padded_text_fr['token_ids'][:1024],padded_text_fr['padding_mask'][:1024]
# #Trans.params = params
# Trans({'token_ids':en_tokens,'padding_mask':en_mask},{'token_ids':fr_tokens,'padding_mask':fr_mask})
#opt = AdamW(lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8, lambda_=0.002)
opt = AdamW(1e-2,9e-1,999e-3,1e-9,1e-2)
trainer = Trainer(Trans,CategoricalCrossEntropy,opt)
trainer.train(({'token_ids':en_tokens,'padding_mask':en_mask},
                {'token_ids':fr_tokens,'padding_mask':fr_mask}),
               fr_tokens,500,batch_size=32)

Epoch  0 
Loss:
6.8926473
Epoch  1 
Loss:
6.8774495
Epoch  2 
Loss:
6.8621492
Epoch  3 
Loss:
6.8467507
Epoch  4 
Loss:
6.8311764
Epoch  5 
Loss:
6.8154383
Epoch  6 
Loss:
6.7997185
Epoch  7 
Loss:
6.7837515
Epoch  8 
Loss:
6.7678725
Epoch  9 
Loss:
6.7518153
Epoch  10 
Loss:
6.7358557
Epoch  11 
Loss:
6.7200174
Epoch  12 
Loss:
6.7041483
Epoch  13 
Loss:
6.6878127
Epoch  14 
Loss:
6.6718594
Epoch  15 
Loss:
6.6560326
Epoch  16 
Loss:
6.6401478
Epoch  17 
Loss:
6.6241474
Epoch  18 
Loss:
6.6085376
Epoch  19 
Loss:
6.5932787
Epoch  20 
Loss:
6.5780363
Epoch  21 
Loss:
6.5614834
Epoch  22 
Loss:
6.5458536
Epoch  23 
Loss:
6.5306344
Epoch  24 
Loss:
6.5144577
Epoch  25 
Loss:
6.5001183
Epoch  26 
Loss:
6.4854458
Epoch  27 
Loss:
6.4694576
Epoch  28 
Loss:
6.4542513
Epoch  29 
Loss:
6.4409345
Epoch  30 
Loss:
6.4236565
Epoch  31 
Loss:
6.4095564
Epoch  32 
Loss:
6.3970894
Epoch  33 
Loss:
6.3811674
Epoch  34 
Loss:
6.3659296
Epoch  35 
Loss:
6.3526234
Epoch  36 
Loss:
6.3392717
Epoch  37 


In [235]:
grads = grad(CategoricalCrossEntropy)(Trans,({'token_ids':en_tokens[:32],'padding_mask':en_mask[:32]},
                 {'token_ids':fr_tokens[:32],'padding_mask':fr_mask[:32]}),
                one_hot(Padder.left_shift(jnp.array(fr_tokens[:32]),5),1000)).params

In [102]:
Trans.params.components[1].name

'Out_Embedding'

In [297]:
Trans({'token_ids':padded_text_en['token_ids'][0:2],'padding_mask':padded_text_en['padding_mask'][0:2]},
      {'token_ids':padded_text_fr['token_ids'][0:2],'padding_mask':padded_text_fr['padding_mask'][0:2]})

In_Embeddings [[[ 9.3224767e-04  1.0073498e+00  3.6200152e-03 ...  1.0101473e+00
    1.2534949e-03  9.9859720e-01]
  [ 8.4026641e-01  5.4348910e-01  8.1956679e-01 ...  9.9350989e-01
    5.5423761e-03  9.9502671e-01]
  [ 9.1250330e-01 -4.1299888e-01  9.3438274e-01 ...  1.0023692e+00
   -2.5457989e-03  9.9514318e-01]
  ...
  [ 1.0000000e-12  1.0000000e-12  1.0000000e-12 ...  1.0000000e-12
    1.0000000e-12  1.0000000e-12]
  [ 1.0000000e-12  1.0000000e-12  1.0000000e-12 ...  1.0000000e-12
    1.0000000e-12  1.0000000e-12]
  [ 1.0000000e-12  1.0000000e-12  1.0000000e-12 ...  1.0000000e-12
    1.0000000e-12  1.0000000e-12]]

 [[-8.5690310e-03  1.0059215e+00 -1.4283760e-04 ...  9.9509794e-01
    4.2114947e-03  1.0094552e+00]
  [ 8.4139228e-01  5.4529727e-01  8.1881285e-01 ...  1.0031143e+00
    1.9937446e-03  9.9990916e-01]
  [ 9.1022974e-01 -4.0879697e-01  9.4003475e-01 ...  1.0101473e+00
    1.4608215e-03  9.9859720e-01]
  ...
  [ 1.0000000e-12  1.0000000e-12  1.0000000e-12 ...  1.0000000e

Array([[[0.00101055, 0.0009902 , 0.00100665, ..., 0.00098968,
         0.00101313, 0.00099828],
        [0.00101595, 0.00098763, 0.00101333, ..., 0.0009896 ,
         0.00100887, 0.00099511],
        [0.00101997, 0.00099128, 0.00101522, ..., 0.00099178,
         0.00100284, 0.0009948 ],
        ...,
        [0.00097679, 0.00102432, 0.00103066, ..., 0.00100157,
         0.00099764, 0.00099517],
        [0.00096693, 0.00102379, 0.00102864, ..., 0.00100948,
         0.00100275, 0.00101981],
        [0.00095738, 0.00102983, 0.00103375, ..., 0.00099378,
         0.00100386, 0.00100513]],

       [[0.00101055, 0.0009902 , 0.00100665, ..., 0.00098968,
         0.00101313, 0.00099828],
        [0.00101573, 0.00098744, 0.00101352, ..., 0.00098951,
         0.00100863, 0.00099528],
        [0.00102022, 0.00099106, 0.00101522, ..., 0.00099163,
         0.00100293, 0.00099474],
        ...,
        [0.00097679, 0.00102432, 0.00103066, ..., 0.00100157,
         0.00099764, 0.00099517],
        [0.0

In [238]:
CategoricalCrossEntropy(Trans,({'token_ids':padded_text_en['token_ids'][65:68],'padding_mask':padded_text_en['padding_mask'][65:68]},
      {'token_ids':padded_text_fr['token_ids'][65:68],'padding_mask':padded_text_fr['padding_mask'][65:68]}),
              one_hot(Padder.left_shift(jnp.array(padded_text_fr['token_ids'][65:68]),5),1000))
      

Array(6.907756, dtype=float32)

In [156]:
[fr_tokenizer.detokenize(tokens.tolist()) for tokens in jnp.argmax(Trans({'token_ids':padded_text_en['token_ids'][0:10],'padding_mask':padded_text_en['padding_mask'][0:10]},
      {'token_ids':padded_text_fr['token_ids'][0:10],'padding_mask':padded_text_fr['padding_mask'][0:10]}),axis=-1)]

[[[ 9.3224780e-06  1.0000736e+00  3.6200152e-05 ...  1.0001014e+00
    1.2534951e-05  9.9998599e-01]
  [ 8.4145898e-01  5.4033417e-01  8.2183337e-01 ...  9.9993509e-01
    1.5805045e-04  9.9995029e-01]
  [ 9.0932953e-01 -4.1611534e-01  9.3639439e-01 ...  1.0000237e+00
    1.7979540e-04  9.9995142e-01]
  ...
  [ 1.0000000e-12  1.0000000e-12  1.0000000e-12 ...  1.0000000e-12
    1.0000000e-12  1.0000000e-12]
  [ 1.0000000e-12  1.0000000e-12  1.0000000e-12 ...  1.0000000e-12
    1.0000000e-12  1.0000000e-12]
  [ 1.0000000e-12  1.0000000e-12  1.0000000e-12 ...  1.0000000e-12
    1.0000000e-12  1.0000000e-12]]

 [[-8.5690306e-05  1.0000592e+00 -1.4283750e-06 ...  9.9995100e-01
    4.2114949e-05  1.0000945e+00]
  [ 8.4147024e-01  5.4035223e-01  8.2182580e-01 ...  1.0000311e+00
    1.2256415e-04  9.9999911e-01]
  [ 9.0930676e-01 -4.1607332e-01  9.3645090e-01 ...  1.0001014e+00
    2.1986160e-04  9.9998599e-01]
  ...
  [ 1.0000000e-12  1.0000000e-12  1.0000000e-12 ...  1.0000000e-12
    1.0000

['om celuiTT voisardue chaussures première fou cha fou jusqu fatigué chaussures retour me fatigué me fou chaussures me me fou me fou me me fou fatigué fatigué chaussures demandé fatigué fou fou me chaussures me fer cha fatigué cassé fou fou cassé casséelle fou fatigué fou fatigué fou me demandé mees me fou cassé cassé Qui cassé',
 'om celuiTT voisardue chaussures première fou cha fou jusqu fatigué chaussures retour me fatigué me fou chaussures me me fou me fou me me fou fatigué fatigué chaussures demandé fatigué fou fou me chaussures me fer cha fatigué cassé fou fou cassé casséelle fou fatigué fou fatigué fou me demandé mees me fou cassé cassé Qui cassé',
 'om celuiTT voisardue chaussures première fou cha fou jusqu fatigué chaussures retour me fatigué me fou chaussures me me fou me fou me me fou fatigué fatigué chaussures demandé fatigué fou fou me chaussures me fer cha fatigué cassé fou fou cassé casséelle fou fatigué fou fatigué fou me demandé mees me fou cassé cassé Qui cassé',
 'om

In [619]:
fr_tokenizer.detokenize(padded_text_fr['token_ids'][0:10])

['Salut !',
 'Cours !',
 'Courez !',
 'Qui ?',
 'Ça alors !',
 'Au feu !',
 "À l'aide !",
 'Saute .',
 'Ça suffit !',
 'Stop !']

In [620]:
jnp.argmax(Trans({'token_ids':padded_text_en['token_ids'][0:20],'padding_mask':padded_text_en['padding_mask'][0:20]},
      {'token_ids':padded_text_fr['token_ids'][0:20],'padding_mask':padded_text_fr['padding_mask'][0:20]}),axis=-1)

Array([[ 11, 214, 205,  88,   4,  88,  79,  79,  79,  79,  79,  79,  79,
         79,  79,  79],
       [ 11, 165,   9,  88,   4,  88,  79,  79,  79,  79,  79,  79,  79,
         79,  79,  79],
       [ 11, 165,   9,  88,   4,  88,  79,  79,  79,  79,  79,  79,  79,
         79,  79,  79],
       [156,  17,   4,  88,  79,  79,  79,  79,  79,  79,  79,  79,  79,
         79,  79,  79],
       [156,  88,  88,   4,  88,  79,  79,  79,  79,  79,  79,  79,  79,
         79,  79,  79],
       [ 11,  45,  88,  88,   4,  88,  79,  79,  79,  79,  79,  79,  79,
         79,  79,  79],
       [ 11, 381,  25,   8,  35,  88,   4,  88,  79,  79,  79,  79,  79,
         79,  79,  79],
       [ 11, 214, 118,   7,   6,   4,  88,  79,  79,  79,  79,  79,  79,
         79,  79,  79],
       [ 11,   7,  66,  61,  61,  88,   4,  88,  79,  79,  79,  79,  79,
         79,  79,  79],
       [ 11,  86,  82,  88,   4,  88,  79,  79,  79,  79,  79,  79,  79,
         79,  79,  79],
       [ 11,   7,  12, 186,  8

In [582]:
en_text[0:10]

['Hi.',
 'Run!',
 'Run!',
 'Who?',
 'Wow!',
 'Fire!',
 'Help!',
 'Jump.',
 'Stop!',
 'Stop!']

In [583]:
fr_text[0:10]

['Salut!',
 'Cours\u202f!',
 'Courez\u202f!',
 'Qui ?',
 'Ça alors\u202f!',
 'Au feu !',
 "À l'aide\u202f!",
 'Saute.',
 'Ça suffit\u202f!',
 'Stop\u202f!']

In [584]:
fr_tokenizer.detokenize(padded_text_fr['token_ids'][120:130])

['Fous le camp !',
 "Pars d'ici .",
 "Va t'en !",
 'Disparais !',
 'Allez -vous en !',
 'Rentrez à la maison .',
 'Rentre à la maison .',
 'Rentre chez toi .',
 'Rentrez chez vous .',
 'Va doucement !']

In [529]:
en_tokenizer.detokenize(padded_text_en['token_ids'][96:100])

['Call us .', 'Come in .', 'Come in .', 'Come in .']

In [3]:
#padded_text_en['token_ids'][120:130]

In [104]:
# Emb(jnp.array(padded_text_en['token_ids'][:2]))

Array([[[ 0.00373945,  0.98045063,  0.01434524, ...,  1.0075562 ,
          0.02784876,  1.0029453 ],
        [ 0.8286708 ,  0.5332071 ,  0.8206814 , ...,  0.9911458 ,
          0.01865378,  1.0128202 ],
        [ 0.90645593, -0.3918722 ,  0.947119  , ...,  0.99548614,
         -0.02322226,  1.0044475 ],
        ...,
        [-0.65509343, -0.7416719 ,  0.3009365 , ...,  1.0073984 ,
          0.00354829,  1.0133736 ],
        [-0.9794912 ,  0.16063708, -0.6112334 , ...,  1.007398  ,
          0.00365195,  1.0133733 ],
        [-0.3954972 ,  0.921128  , -0.9920058 , ...,  1.0073977 ,
          0.00375561,  1.0133729 ]],

       [[ 0.00373945,  0.98045063,  0.01434524, ...,  1.0075562 ,
          0.02784876,  1.0029453 ],
        [ 0.8428125 ,  0.5345792 ,  0.8314176 , ...,  0.9894027 ,
          0.00636642,  1.0055251 ],
        [ 0.8900799 , -0.42708257,  0.9419847 , ...,  0.99656487,
         -0.0027082 ,  1.0100876 ],
        ...,
        [-0.65509343, -0.7416719 ,  0.3009365 , ...,  

In [633]:
class Translator:
    def __init__(self,model,en_tokenizer,fr_tokenizer,en_padder,fr_padder):
        self.model = model
        self.en_tokenizer = en_tokenizer
        self.fr_tokenizer = fr_tokenizer
        self.en_padder = en_padder
        self.fr_padder = fr_padder
    def __call__(self,text):
        en_tokens = self.en_tokenizer([text])
        #print(en_tokens)
        padded_text_en = self.en_padder(en_tokens)
        print(padded_text_en)
        #fr_tokens = [[3]]
        count = 0
        pred_token = 0
        fr_text = "A"
        while pred_token!=4:
            fr_tokens = self.fr_tokenizer([fr_text])
            padded_text_fr = self.fr_padder(fr_tokens)
            #print(padded_text_fr['token_ids'])
            index=padded_text_fr['token_ids'][0].index(4)
            padded_text_fr['token_ids'][0][index] = 5
            padded_text_fr['padding_mask'][0][index] = 0
            print(padded_text_fr)
            predicted_tokens = jnp.argmax(self.model({'token_ids':padded_text_en['token_ids'],'padding_mask':padded_text_en['padding_mask']},
                   {'token_ids':padded_text_fr['token_ids'],'padding_mask':padded_text_fr['padding_mask']}),axis=-1)
            #print(predicted_tokens)
            pred_token = predicted_tokens[0][count]
            fr_text+=self.fr_tokenizer.detokenize(pred_token.tolist())
            count+=1
        #print(predicted_tokens)
        print([self.fr_tokenizer.detokenize(tokens.tolist()) for tokens in predicted_tokens])
        print(fr_text)

In [634]:
Later = Translator(Trans,en_tokenizer,fr_tokenizer,padding_en,padding_fr)

In [635]:
Later("")

{'token_ids': [[3, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5]], 'padding_mask': [[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]}
{'token_ids': [[3, 213, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5]], 'padding_mask': [[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]}
{'token_ids': [[3, 213, 88, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5]], 'padding_mask': [[1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]}
{'token_ids': [[3, 213, 88, 88, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5]], 'padding_mask': [[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]}
['S !rararararararararararara']
A!!


In [220]:
fr_text[7]R

'Saute.'

In [337]:
en_text[0:4]

['Hi.', 'Run!', 'Run!', 'Who?']

In [336]:
fr_text[0:4]

['Salut!', 'Cours\u202f!', 'Courez\u202f!', 'Qui ?']

In [73]:
import jaxlib
def convert_weights(params):
    params = copy.deepcopy(params)
    for key in params:
        if type(params[key])==jaxlib.xla_extension.ArrayImpl:
            params[key]=params[key].tolist()
        else:
            params[key] = convert_weights(params[key])
    return params
            

In [74]:
model_weights = convert_weights(Trans.params)

In [75]:
weight_string = json.dumps(model_weights)

In [76]:
with open("weights-6-512.json","w") as file:
    file.write(weight_string)

In [552]:
def convert_weights_jax(params):
    params = copy.deepcopy(params)
    for key in params:
        if type(params[key])==list:
            params[key]=jnp.array(params[key])
        else:
            params[key] = convert_weights_jax(params[key])
    return params

In [553]:
with open("weights.json","r") as file:
    weight_string = file.read()
    params = json.loads(weight_string)
    params = convert_weights_jax(params)

In [589]:
#params

In [302]:
# def clip_gradients(params):
#     params = copy.deepcopy(params)
#     for key in params:
#         if type(params[key])==jaxlib.xla_extension.ArrayImpl:
#             params[key] = jnp.clip(params[key],-1.0,1.0)
#             params[key] = jnp.nan_to_num(params[key])
#         else:
#             params[key] = clip_gradients(params[key])
#     return params