In [109]:
# stdlib `random` is aliased to `rnd` because `random` below is jax.random.
import random as rnd

import jax.numpy as jnp
from jax import lax as jlax
from jax import random
from jax.tree_util import register_pytree_node_class

In [None]:
@register_pytree_node_class
class Parameter:
    """A named array supporting element-wise arithmetic, registered as a JAX pytree.

    ``value`` is the only pytree child; ``name`` is carried as auxiliary data,
    so JAX tree utilities and transformations treat only ``value`` as a leaf.

    Supported operators (each returns a new Parameter that keeps ``self.name``):
      * ``p - q``, ``p + q``, ``p / q`` with another Parameter (element-wise);
      * ``p + s``, ``p * s``, ``s * p``, ``p / s`` with a Python ``float`` ``s``;
      * ``p ** k`` with any exponent the underlying array accepts.
    Other operand types raise ``TypeError``.

    Attributes:
        name: label for this parameter.
        value: the wrapped array.
        shape: ``value.shape``, or ``None`` when ``value`` has no shape.
    """
    def __init__(self,name,value):
        self.name = name
        self.value = value
        # JAX can rebuild a pytree from placeholder leaves (e.g. None or
        # object()) when it only needs the tree structure, so ``value`` is not
        # guaranteed to be an array here; reading ``.shape`` unconditionally
        # would crash tree_unflatten.
        self.shape = getattr(value, "shape", None)
    def __sub__(self,param):
        if isinstance(param,Parameter):
            return Parameter(self.name,self.value-param.value)
        raise TypeError(f"unsupported operand type(s) for -: {type(param)} and 'Parameter'")
    def __add__(self,other):
        if isinstance(other,Parameter):
            return Parameter(self.name,self.value+other.value)
        if isinstance(other,float):
            return Parameter(self.name,self.value+other)
        raise TypeError(f"unsupported operand type(s) for +: {type(other)} and 'Parameter'")
    def __mul__(self,factor):
        if isinstance(factor,float):
            return Parameter(self.name,factor*self.value)
        raise TypeError(f'Cannot multiply a Parameter with {type(factor)}')
    def __rmul__(self,factor):
        if isinstance(factor,float):
            return Parameter(self.name,factor*self.value)
        raise TypeError(f'Cannot multiply a Parameter with {type(factor)}')
    def __pow__(self,factor):
        return Parameter(self.name,self.value**factor)
    def __truediv__(self,other):
        if isinstance(other,Parameter):
            return Parameter(self.name,self.value/other.value)
        if isinstance(other,float):
            return Parameter(self.name,self.value/other)
        # Report the actual operand; the message previously referenced an
        # undefined name and raised NameError instead of TypeError.
        raise TypeError(f'Cannot divide a Parameter with {type(other)}')
    def tree_flatten(self):
        """Return ``((value,), (name,))``: the array is the child, the name is aux data."""
        children = (self.value,)
        aux_data = (self.name,)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild a Parameter from the output of ``tree_flatten``."""
        return cls(*aux_data,*children)

In [175]:
@register_pytree_node_class
class LinearParams:
    """Weight matrix of a bias-free linear layer, registered as a JAX pytree.

    ``weights`` is kept as a ``Parameter``; anything else passed in is wrapped
    as ``Parameter("W", weights)``. Operators delegate to ``Parameter`` and
    return a fresh ``LinearParams`` that keeps ``self.name``:

    * ``-``, ``+``, ``/`` with another LinearParams;
    * ``+``, ``*`` (either side), ``/`` with a Python ``float``;
    * ``**`` with any exponent the wrapped Parameter accepts.

    Other operand types raise ``TypeError``.
    """
    def __init__(self,name,weights):
        self.name = name
        self.weights = weights if isinstance(weights, Parameter) else Parameter("W", weights)

    def __sub__(self,other):
        if not isinstance(other, LinearParams):
            raise TypeError(f"unsupported operand type(s) for -: {type(other)} and 'LinearParams'")
        return LinearParams(self.name, self.weights - other.weights)

    def __add__(self,other):
        if isinstance(other, LinearParams):
            operand = other.weights
        elif isinstance(other, float):
            operand = other
        else:
            raise TypeError(f"unsupported operand type(s) for +: {type(other)} and 'LinearParams'")
        return LinearParams(self.name, self.weights + operand)

    def __mul__(self,other):
        if not isinstance(other, float):
            raise TypeError(f"Cannot multiply a 'LinearParams' with {type(other)}")
        return LinearParams(self.name, self.weights * other)

    # Scalar multiplication is symmetric: ``s * lp`` behaves exactly like ``lp * s``.
    __rmul__ = __mul__

    def __truediv__(self,other):
        if isinstance(other, LinearParams):
            divisor = other.weights
        elif isinstance(other, float):
            divisor = other
        else:
            raise TypeError(f"Cannot divide a 'LinearParams' with {type(other)}")
        return LinearParams(self.name, self.weights / divisor)

    def __pow__(self,factor):
        return LinearParams(self.name, self.weights ** factor)

    def tree_flatten(self):
        """The Parameter is the only child; the layer name is static aux data."""
        return (self.weights,), (self.name,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild a LinearParams from the output of ``tree_flatten``."""
        return cls(*aux_data, *children)

In [207]:
@register_pytree_node_class
class FeedForwardParams:
    """Weights and bias of a feed-forward layer, registered as a JAX pytree.

    ``weights`` and ``bias`` are stored as ``Parameter`` objects; raw arrays are
    wrapped as ``Parameter("W", ...)`` and ``Parameter("bias", ...)``.
    Arithmetic is applied to weights and bias independently and returns a new
    ``FeedForwardParams`` that keeps ``self.name``:

    * ``-``, ``+``, ``/`` with another FeedForwardParams (element-wise);
    * ``+``, ``*`` (either side), ``/`` with a Python ``float``;
    * ``**`` with any exponent the underlying Parameters accept.

    Other operand types raise ``TypeError``.
    """
    def __init__(self,name,weights,bias):
        self.name = name
        if isinstance(weights,Parameter):
            self.weights = weights
        else:
            self.weights = Parameter("W",weights)
        if isinstance(bias,Parameter):
            self.bias = bias
        else:
            # Auto-assigned name is "bias" (was misspelled "bais").
            self.bias = Parameter("bias",bias)
    def __sub__(self,other):
        if isinstance(other,FeedForwardParams) :
            return FeedForwardParams(self.name,self.weights-other.weights,self.bias-other.bias)
        raise TypeError(f"unsupported operand type(s) for -: {type(other)} and 'FeedForwardParams'")
    def __add__(self,other):
        if isinstance(other,FeedForwardParams):
            return FeedForwardParams(self.name,self.weights+other.weights,self.bias+other.bias)
        if isinstance(other,float):
            return FeedForwardParams(self.name,self.weights+other,self.bias+other)
        raise TypeError(f"unsupported operand type(s) for +: {type(other)} and 'FeedForwardParams'")
    def __mul__(self,other):
        if isinstance(other,float):
            return FeedForwardParams(self.name,self.weights*other,self.bias*other)
        raise TypeError(f"Cannot multiply a 'FeedForwardParams' with {type(other)}")
    def __rmul__(self,other):
        # Scalar multiplication commutes, so ``s * ff`` is the same as ``ff * s``.
        return self.__mul__(other)
    def __truediv__(self,other):
        if isinstance(other,FeedForwardParams):
            return FeedForwardParams(self.name,self.weights/other.weights,self.bias/other.bias)
        if isinstance(other,float):
            return FeedForwardParams(self.name,self.weights/other,self.bias/other)
        raise TypeError(f"Cannot divide a 'FeedForwardParams' with {type(other)}")
    def __pow__(self,factor):
        return FeedForwardParams(self.name,self.weights**factor,self.bias**factor)
    def tree_flatten(self):
        """Return ``((weights, bias), (name,))``; the name is static aux data."""
        children = (self.weights,self.bias,)
        aux_data = (self.name,)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild a FeedForwardParams from the output of ``tree_flatten``."""
        return cls(*aux_data,*children)
    

In [177]:
@register_pytree_node_class
class AttentionParams:
    """Key/query/value matrices of one attention head, registered as a JAX pytree.

    ``w_k``, ``w_q`` and ``w_v`` are stored as ``Parameter`` objects; raw arrays
    are wrapped under the names "w_k", "w_q" and "w_v". Every operator acts on
    the three matrices independently and returns a new ``AttentionParams``
    carrying ``self.name``:

    * ``-``, ``+``, ``/`` with another AttentionParams (matrix by matrix);
    * ``+``, ``*`` (either side), ``/`` with a Python ``float``;
    * ``**`` with any exponent the underlying Parameters accept.

    Other operand types raise ``TypeError``.
    """
    def __init__(self,name,w_k,w_q,w_v):
        self.name = name
        self.w_k = self._as_parameter("w_k", w_k)
        self.w_q = self._as_parameter("w_q", w_q)
        self.w_v = self._as_parameter("w_v", w_v)

    @staticmethod
    def _as_parameter(label, value):
        """Return ``value`` unchanged if it is already a Parameter, else wrap it under ``label``."""
        return value if isinstance(value, Parameter) else Parameter(label, value)

    def _elementwise(self, op, other):
        """Combine matching matrices of ``self`` and ``other`` with ``op``."""
        return AttentionParams(self.name,
                               op(self.w_k, other.w_k),
                               op(self.w_q, other.w_q),
                               op(self.w_v, other.w_v))

    def _with_scalar(self, op, scalar):
        """Apply ``op(matrix, scalar)`` to each of the three matrices."""
        return AttentionParams(self.name,
                               op(self.w_k, scalar),
                               op(self.w_q, scalar),
                               op(self.w_v, scalar))

    def __sub__(self,other):
        if not isinstance(other, AttentionParams):
            raise TypeError(f"unsupported operand type(s) for -: {type(other)} and 'AttentionParams'")
        return self._elementwise(lambda a, b: a - b, other)

    def __add__(self,other):
        if isinstance(other, AttentionParams):
            return self._elementwise(lambda a, b: a + b, other)
        if isinstance(other, float):
            return self._with_scalar(lambda a, s: a + s, other)
        raise TypeError(f"unsupported operand type(s) for +: {type(other)} and 'AttentionParams'")

    def __mul__(self,other):
        if not isinstance(other, float):
            raise TypeError(f"Cannot multiply a 'AttentionParams' with {type(other)}")
        return self._with_scalar(lambda a, s: a * s, other)

    # ``s * head`` and ``head * s`` perform the same scaling.
    __rmul__ = __mul__

    def __truediv__(self,other):
        if isinstance(other, AttentionParams):
            return self._elementwise(lambda a, b: a / b, other)
        if isinstance(other, float):
            return self._with_scalar(lambda a, s: a / s, other)
        raise TypeError(f"Cannot divide a 'AttentionParams' with {type(other)}")

    def __pow__(self,factor):
        return self._with_scalar(lambda a, k: a ** k, factor)

    def tree_flatten(self):
        """The three Parameters are children; the head name is static aux data."""
        return (self.w_k, self.w_q, self.w_v), (self.name,)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild an AttentionParams from the output of ``tree_flatten``."""
        return cls(*aux_data, *children)

In [195]:
@register_pytree_node_class
class MultiHeadAttentionParams:
    """Parameters of a multi-head attention block, registered as a JAX pytree.

    Holds ``weights`` (a ``Parameter``; raw arrays are wrapped as
    ``Parameter("Wo", ...)`` — presumably the output projection) and a list of
    per-head parameter objects. Operators act on the weights and on each head
    independently and return a new ``MultiHeadAttentionParams`` that keeps
    ``self.name``:

    * ``-``, ``+``, ``/`` with another MultiHeadAttentionParams (head by head);
    * ``+``, ``*`` (either side), ``/`` with a Python ``float``;
    * ``**`` with any exponent the weights and heads accept.

    Pairwise operations raise ``ValueError`` when the head counts differ;
    unsupported operand types raise ``TypeError``.
    """
    def __init__(self,name,weights,heads:list[AttentionParams]):
        self.name = name
        if isinstance(weights,Parameter):
            self.weights = weights
        else:
            self.weights = Parameter("Wo",weights)
        self.heads = heads
        self.num_heads = len(heads)
    def add(self,head1,head2):
        return head1+head2
    def subtract(self,head1,head2):
        return head1-head2
    def multiply(self,val,head):
        return val*head
    def divide(self,head,val):
        return head/val
    def pow(self,val,head):
        return head**val
    def _check_same_num_heads(self,other):
        """Raise ValueError unless ``other`` has as many heads as ``self``."""
        # map() stops at the shorter list, so without this check a head-count
        # mismatch would silently drop heads from the result.
        if len(self.heads) != len(other.heads):
            raise ValueError(
                f"cannot combine MultiHeadAttentionParams with {len(self.heads)} "
                f"and {len(other.heads)} heads")
    def __sub__(self,other):
        if isinstance(other,MultiHeadAttentionParams):
            self._check_same_num_heads(other)
            heads = list(map(self.subtract,self.heads,other.heads))
            return MultiHeadAttentionParams(self.name,self.weights-other.weights,heads)
        raise TypeError(f"unsupported operand type(s) for -: {type(other)} and 'MultiHeadAttentionParams'")
    def __add__(self,other):
        if isinstance(other,MultiHeadAttentionParams) :
            self._check_same_num_heads(other)
            heads = list(map(self.add,self.heads,other.heads))
            return MultiHeadAttentionParams(self.name,self.weights+other.weights,heads)
        if isinstance(other,float):
            heads = [self.add(head,other) for head in self.heads]
            return MultiHeadAttentionParams(self.name,self.weights+other,heads)
        raise TypeError(f"unsupported operand type(s) for +: {type(other)} and 'MultiHeadAttentionParams'")
    def __mul__(self,other):
        if isinstance(other,float):
            heads = [self.multiply(other,head) for head in self.heads]
            return MultiHeadAttentionParams(self.name,self.weights*other,heads)
        raise TypeError(f"Cannot multiply a 'MultiHeadAttentionParams' with {type(other)}")
    def __truediv__(self,other):
        if isinstance(other,MultiHeadAttentionParams) :
            self._check_same_num_heads(other)
            heads = list(map(self.divide,self.heads,other.heads))
            return MultiHeadAttentionParams(self.name,self.weights/other.weights,heads)
        if isinstance(other,float):
            heads = [self.divide(head,other) for head in self.heads]
            return MultiHeadAttentionParams(self.name,self.weights/other,heads)
        # Message previously said "multiply" for a failed division.
        raise TypeError(f"Cannot divide a 'MultiHeadAttentionParams' with {type(other)}")
    def __rmul__(self,other):
        # Scalar multiplication commutes, so ``s * mha`` is the same as ``mha * s``.
        return self.__mul__(other)
    def __pow__(self,factor):
        heads = [self.pow(factor,head) for head in self.heads]
        return MultiHeadAttentionParams(self.name,self.weights**factor,heads)
    def tree_flatten(self):
        """Return ``((weights, heads), (name,))``; the heads list is itself a pytree."""
        children = (self.weights,self.heads,)
        aux_data = (self.name,)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild a MultiHeadAttentionParams from the output of ``tree_flatten``."""
        return cls(*aux_data,*children)

In [232]:
@register_pytree_node_class
class ModuleParams:
    """A named, ordered list of component parameter objects, as a JAX pytree.

    ``components`` may hold any objects supporting the arithmetic used below
    (e.g. FeedForwardParams, MultiHeadAttentionParams). Each operator is applied
    component by component and returns a new ``ModuleParams`` carrying
    ``self.name``:

    * ``-``, ``+``, ``/`` with another ModuleParams (pairwise, by position);
    * ``+``, ``*`` (either side), ``/`` with a Python ``float``;
    * ``**`` with any exponent the components accept.

    Pairwise operations raise ``ValueError`` when the component counts differ;
    unsupported operand types raise ``TypeError``.
    """
    def __init__(self,name,components):
        self.name = name
        self.components = components
        self.num_comps = len(components)
    def multiply(self,val,comp):
        return val*comp
    def subtract(self,comp1,comp2):
        return comp1-comp2
    def add(self,comp1,comp2):
        return comp1+comp2
    def pow(self,val,comp):
        return comp**val
    def divide(self,comp,val):
        return comp/val
    def _check_same_length(self,other):
        """Raise ValueError unless ``other`` has as many components as ``self``."""
        # map() stops at the shorter list, so a length mismatch would otherwise
        # silently drop components from the result.
        if len(self.components) != len(other.components):
            raise ValueError(
                f"cannot combine ModuleParams {self.name!r} ({len(self.components)} components) "
                f"with {other.name!r} ({len(other.components)} components)")
    def __sub__(self,other):
        if isinstance(other,ModuleParams):
            self._check_same_length(other)
            comps = list(map(self.subtract,self.components,other.components))
            return ModuleParams(self.name,comps)
        raise TypeError(f"unsupported operand type(s) for -: {type(other)} and 'ModuleParams'")
    def __add__(self,other):
        if isinstance(other,ModuleParams) :
            self._check_same_length(other)
            comps = list(map(self.add,self.components,other.components))
            return ModuleParams(self.name,comps)
        if isinstance(other,float):
            comps = [self.add(comp,other) for comp in self.components]
            return ModuleParams(self.name,comps)
        raise TypeError(f"unsupported operand type(s) for +: {type(other)} and 'ModuleParams'")
    def __mul__(self,other):
        if isinstance(other,float):
            comps = [self.multiply(other,comp) for comp in self.components]
            return ModuleParams(self.name,comps)
        raise TypeError(f"Cannot multiply a 'ModuleParams' with {type(other)}")
    def __truediv__(self,other):
        if isinstance(other,ModuleParams) :
            self._check_same_length(other)
            comps = list(map(self.divide,self.components,other.components))
            return ModuleParams(self.name,comps)
        if isinstance(other,float):
            comps = [self.divide(comp,other) for comp in self.components]
            return ModuleParams(self.name,comps)
        raise TypeError(f"Cannot divide a 'ModuleParams' with {type(other)}")
    def __rmul__(self,other):
        # Scalar multiplication commutes, so ``s * module`` equals ``module * s``.
        return self.__mul__(other)
    def __pow__(self,factor):
        comps = [self.pow(factor,comp) for comp in self.components]
        return ModuleParams(self.name,comps)
    def tree_flatten(self):
        """Return ``((components,), (name,))``; the components list is itself a pytree."""
        children = (self.components,)
        aux_data = (self.name,)
        return (children, aux_data)
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild a ModuleParams from the output of ``tree_flatten``."""
        return cls(*aux_data,*children)

In [72]:
# A standalone (24, 16) standard-normal Parameter drawn from jax.random key 29.
weights1 = Parameter(name="W", value=random.normal(random.key(29), shape=(24, 16)))

In [73]:
# NOTE(review): random.key(29) is the same key used for weights1, so this array
# is identical to weights1.value — confirm that is intended.
weights2 = Parameter(name="dw", value=random.normal(random.key(29), shape=(24, 16)))

In [198]:
# Two independent sets of attention heads. Every weight matrix draws from its
# own key, split from a fixed seed, so Restart & Run All reproduces the same
# values and no two matrices can share a key (the previous
# random.key(rnd.randint(0, 100)) draws were unseeded and could repeat).
# heads1 used to be left empty, which made mha1 a zero-head block.
NUM_HEADS = 4
HEAD_SHAPE = (24, 16)

def make_heads(key, num_heads=NUM_HEADS, shape=HEAD_SHAPE):
    """Return ``num_heads`` AttentionParams named H0, H1, ... with standard-normal weights.

    Args:
        key: a jax.random key; it is split so each head and each of its
            w_k / w_q / w_v matrices gets a distinct key.
        num_heads: number of heads to build.
        shape: shape of every w_k / w_q / w_v matrix.
    """
    heads = []
    for i, head_key in enumerate(random.split(key, num_heads)):
        k_key, q_key, v_key = random.split(head_key, 3)
        heads.append(AttentionParams(f"H{i}",
                                     random.normal(k_key, shape=shape),
                                     random.normal(q_key, shape=shape),
                                     random.normal(v_key, shape=shape)))
    return heads

heads1_key, heads2_key = random.split(random.key(0))
heads1 = make_heads(heads1_key)
heads2 = make_heads(heads2_key)

In [208]:
# MultiHeadAttentionParams(name, weights, heads) needs its weights as the second
# argument; calling it with only (name, heads) raises TypeError on a fresh kernel.
# Every matrix gets its own key split from a fixed seed, so results are
# reproducible and no two matrices can share a key.
# NOTE(review): (24, 16) for the MHA weights and the feed-forward bias just
# mirrors the other matrices in this notebook — confirm the intended shapes.
PARAM_SHAPE = (24, 16)
wo1_key, wo2_key, ff1_w_key, ff1_b_key, ff2_w_key, ff2_b_key = random.split(random.key(1), 6)
mha1 = MultiHeadAttentionParams("MHA1", random.normal(wo1_key, shape=PARAM_SHAPE), heads1)
mha2 = MultiHeadAttentionParams("MHA2", random.normal(wo2_key, shape=PARAM_SHAPE), heads2)
ff1 = FeedForwardParams("FF1",
                        random.normal(ff1_w_key, shape=PARAM_SHAPE),
                        random.normal(ff1_b_key, shape=PARAM_SHAPE))
ff2 = FeedForwardParams("FF2",
                        random.normal(ff2_w_key, shape=PARAM_SHAPE),
                        random.normal(ff2_b_key, shape=PARAM_SHAPE))


In [233]:
# The same four components in mirrored order, so element-wise arithmetic between
# the two modules pairs ff2/ff1, mha2/mha1, mha1/mha2 and ff1/ff2.
comps1 = [ff2, mha2, mha1, ff1]
comps2 = comps1[::-1]

In [234]:
# Wrap each component list in a ModuleParams; both modules share the name "Encoder".
Module1, Module2 = (ModuleParams("Encoder", comps) for comps in (comps1, comps2))

In [235]:
# Module1 minus 0.001 x Module2, component by component (an SGD-style step if
# Module2 holds gradients — assumption). The scale must be a Python float:
# the *Params operators raise TypeError for ints.
Module3 = Module1 - (Module2 * 0.001)

In [236]:
# Last expression of the cell, so Jupyter renders the components list as output.
# The component classes define no __repr__, so each entry shows only its class
# and memory address.
Module3.components

[<__main__.FeedForwardParams at 0x7fafc4067d30>,
 <__main__.MultiHeadAttentionParams at 0x7fafc4067c70>,
 <__main__.MultiHeadAttentionParams at 0x7fafc4067c10>,
 <__main__.FeedForwardParams at 0x7fafc4067b50>]