In [52]:
import dataclasses
import jax
from flax import struct
from types import MappingProxyType

@dataclasses.dataclass
class Params_template:
    """Template listing every parameter as a dataclass field.

    All fields are declared with ``struct.field(pytree_node=False)``.
    ``build_params_class`` reads these fields through ``dataclasses.fields``
    to assemble the concrete ``Params`` class, clearing the metadata of the
    fields that require gradients. No field declares a default value.
    """
    eField: float = struct.field(pytree_node=False)
    Ab: float = struct.field(pytree_node=False)
    kb: float = struct.field(pytree_node=False)
    lifetime: float = struct.field(pytree_node=False)
    vdrift: float = struct.field(pytree_node=False)
    long_diff: float = struct.field(pytree_node=False)
    tran_diff: float = struct.field(pytree_node=False)
    tpc_borders: jax.Array = struct.field(pytree_node=False)
    box: int = struct.field(pytree_node=False)
    birks: int = struct.field(pytree_node=False)
    lArDensity: float = struct.field(pytree_node=False)
    alpha: float = struct.field(pytree_node=False)
    beta: float = struct.field(pytree_node=False)
    MeVToElectrons: float = struct.field(pytree_node=False)
    pixel_pitch: float = struct.field(pytree_node=False)
    n_pixels: tuple = struct.field(pytree_node=False)
    max_radius: int = struct.field(pytree_node=False)
    max_active_pixels: int = struct.field(pytree_node=False)
    drift_length: float = struct.field(pytree_node=False)
    t_sampling: float = struct.field(pytree_node=False)
    time_interval: float = struct.field(pytree_node=False)
    time_padding: float = struct.field(pytree_node=False)
    min_step_size: float = struct.field(pytree_node=False)
    time_max: float = struct.field(pytree_node=False)
    time_window: float = struct.field(pytree_node=False)
    e_charge: float = struct.field(pytree_node=False)
    #: Maximum number of ADC values stored per pixel
    MAX_ADC_VALUES: int = struct.field(pytree_node=False)
    #: Discrimination threshold
    DISCRIMINATION_THRESHOLD: float = struct.field(pytree_node=False)
    #: ADC hold delay in clock cycles
    ADC_HOLD_DELAY: int = struct.field(pytree_node=False)
    #: Clock cycle time in :math:`\mu s`
    CLOCK_CYCLE: float = struct.field(pytree_node=False)
    #: Front-end gain in :math:`mV/ke-`
    GAIN: float = struct.field(pytree_node=False)
    #: Common-mode voltage in :math:`mV`
    V_CM: float = struct.field(pytree_node=False)
    #: Reference voltage in :math:`mV`
    V_REF: float = struct.field(pytree_node=False)
    #: Pedestal voltage in :math:`mV`
    V_PEDESTAL: float = struct.field(pytree_node=False)
    #: Number of ADC counts
    ADC_COUNTS: int = struct.field(pytree_node=False)
    # NOTE(review): commented-out legacy logic used to pick the two noise
    # values below from a `readout_noise` flag (RESET_NOISE_CHARGE = 900 e-
    # and UNCORRELATED_NOISE_CHARGE = 500 e- when enabled; the else-branch
    # values were not recorded). Here both are plain fields set by the caller.
    #: Reset noise in e-
    RESET_NOISE_CHARGE: float = struct.field(pytree_node=False)
    #: Uncorrelated noise in e-
    UNCORRELATED_NOISE_CHARGE: float = struct.field(pytree_node=False)

def build_params_class(params_with_grad=()):
    """Build the concrete ``Params`` class from ``Params_template``.

    Every template field is declared with ``pytree_node=False``. For each
    field whose name appears in ``params_with_grad`` that metadata is replaced
    by an empty mapping, i.e. the ``pytree_node=False`` flag is removed so the
    field can take part in gradient calculation. The resulting fields are put
    on a freshly created class named ``Params`` which is then handed to
    ``struct.dataclass``.

    The template itself is left untouched: fields are copied before being
    modified, so repeated calls with different ``params_with_grad`` do not
    influence each other.

    Args:
        params_with_grad: Iterable of field names that require gradients.
            Names that match no template field are ignored. Defaults to an
            empty tuple, in which case every field keeps ``pytree_node=False``.

    Returns:
        The class produced by ``struct.dataclass`` (named ``Params``).
    """
    import copy  # function-scope import: only needed here to clone Field objects

    grad_names = frozenset(params_with_grad)

    fields = []
    for template_field in dataclasses.fields(Params_template):
        # dataclasses.fields() returns the Field objects owned by the template
        # class; assigning to them directly would permanently change the
        # template, so each one is shallow-copied first.
        field = copy.copy(template_field)
        if field.name in grad_names:
            # Removing the pytree_node=False flag for variables requiring gradients
            field.metadata = MappingProxyType({})
        fields.append(field)

    # Dynamically create a plain class carrying the Field objects as class
    # attributes plus matching annotations; struct.dataclass then processes it
    # like a hand-written class body.
    base_class = type("Params", (object,), {field.name: field for field in fields})
    base_class.__annotations__ = {field.name: field.type for field in fields}
    return struct.dataclass(base_class)

# Build the concrete Params class. The gradient-enabled field list is passed
# explicitly: in the recorded output of the next cell only `eField` has its
# pytree_node=False metadata cleared, while `Ab` and `kb` keep it.
Params = build_params_class(['eField'])

In [53]:
# Inspect the fields of the generated class; in the recorded output `eField`
# has empty metadata while `Ab` and `kb` still carry {'pytree_node': False}.
dataclasses.fields(Params)

(Field(name='eField',type=<class 'float'>,default=<dataclasses._MISSING_TYPE object at 0x7f3ec0dcb970>,default_factory=<dataclasses._MISSING_TYPE object at 0x7f3ec0dcb970>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({}),kw_only=False,_field_type=_FIELD),
 Field(name='Ab',type=<class 'float'>,default=<dataclasses._MISSING_TYPE object at 0x7f3ec0dcb970>,default_factory=<dataclasses._MISSING_TYPE object at 0x7f3ec0dcb970>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'pytree_node': False}),kw_only=False,_field_type=_FIELD),
 Field(name='kb',type=<class 'float'>,default=<dataclasses._MISSING_TYPE object at 0x7f3ec0dcb970>,default_factory=<dataclasses._MISSING_TYPE object at 0x7f3ec0dcb970>,init=True,repr=True,hash=None,compare=True,metadata=mappingproxy({'pytree_node': False}),kw_only=False,_field_type=_FIELD),
 Field(name='lifetime',type=<class 'float'>,default=<dataclasses._MISSING_TYPE object at 0x7f3ec0dcb970>,default_factory=<dataclasses._MISSI

In [57]:
# Show the class namespace of the generated Params class (module, auto-generated
# signature docstring, annotations, ...).
vars(Params)

mappingproxy({'__module__': '__main__',
              '__dict__': <attribute '__dict__' of 'Params' objects>,
              '__weakref__': <attribute '__weakref__' of 'Params' objects>,
              '__doc__': 'Params(eField: float, Ab: float, kb: float, lifetime: float, vdrift: float, long_diff: float, tran_diff: float, tpc_borders: jax.Array, box: int, birks: int, lArDensity: float, alpha: float, beta: float, MeVToElectrons: float, pixel_pitch: float, n_pixels: tuple, max_radius: int, max_active_pixels: int, drift_length: float, t_sampling: float, time_interval: float, time_padding: float, min_step_size: float, time_max: float, time_window: float, e_charge: float, MAX_ADC_VALUES: int, DISCRIMINATION_THRESHOLD: float, ADC_HOLD_DELAY: int, CLOCK_CYCLE: float, GAIN: float, V_CM: float, V_REF: float, V_PEDESTAL: float, ADC_COUNTS: int, RESET_NOISE_CHARGE: float, UNCORRELATED_NOISE_CHARGE: float)',
              '__annotations__': {'eField': float,
               'Ab': float,
           

In [38]:
# NOTE(review): this cell (execution count 38) duplicates the Params_template
# definition from the first cell and should be deleted. Until it is, it carries
# the same @dataclasses.dataclass decorator: without it, running the notebook
# top to bottom rebinds Params_template to a plain class and any later
# build_params_class call fails inside dataclasses.fields().
@dataclasses.dataclass
class Params_template:
    """Template listing every parameter as a dataclass field.

    All fields are declared with ``struct.field(pytree_node=False)`` and none
    declares a default value. ``build_params_class`` reads the fields through
    ``dataclasses.fields`` to assemble the concrete ``Params`` class.
    """
    eField: float = struct.field(pytree_node=False)
    Ab: float = struct.field(pytree_node=False)
    kb: float = struct.field(pytree_node=False)
    lifetime: float = struct.field(pytree_node=False)
    vdrift: float = struct.field(pytree_node=False)
    long_diff: float = struct.field(pytree_node=False)
    tran_diff: float = struct.field(pytree_node=False)
    tpc_borders: jax.Array = struct.field(pytree_node=False)
    box: int = struct.field(pytree_node=False)
    birks: int = struct.field(pytree_node=False)
    lArDensity: float = struct.field(pytree_node=False)
    alpha: float = struct.field(pytree_node=False)
    beta: float = struct.field(pytree_node=False)
    MeVToElectrons: float = struct.field(pytree_node=False)
    pixel_pitch: float = struct.field(pytree_node=False)
    n_pixels: tuple = struct.field(pytree_node=False)
    max_radius: int = struct.field(pytree_node=False)
    max_active_pixels: int = struct.field(pytree_node=False)
    drift_length: float = struct.field(pytree_node=False)
    t_sampling: float = struct.field(pytree_node=False)
    time_interval: float = struct.field(pytree_node=False)
    time_padding: float = struct.field(pytree_node=False)
    min_step_size: float = struct.field(pytree_node=False)
    time_max: float = struct.field(pytree_node=False)
    time_window: float = struct.field(pytree_node=False)
    e_charge: float = struct.field(pytree_node=False)
    #: Maximum number of ADC values stored per pixel
    MAX_ADC_VALUES: int = struct.field(pytree_node=False)
    #: Discrimination threshold
    DISCRIMINATION_THRESHOLD: float = struct.field(pytree_node=False)
    #: ADC hold delay in clock cycles
    ADC_HOLD_DELAY: int = struct.field(pytree_node=False)
    #: Clock cycle time in :math:`\mu s`
    CLOCK_CYCLE: float = struct.field(pytree_node=False)
    #: Front-end gain in :math:`mV/ke-`
    GAIN: float = struct.field(pytree_node=False)
    #: Common-mode voltage in :math:`mV`
    V_CM: float = struct.field(pytree_node=False)
    #: Reference voltage in :math:`mV`
    V_REF: float = struct.field(pytree_node=False)
    #: Pedestal voltage in :math:`mV`
    V_PEDESTAL: float = struct.field(pytree_node=False)
    #: Number of ADC counts
    ADC_COUNTS: int = struct.field(pytree_node=False)
    # NOTE(review): commented-out legacy logic used to pick the two noise
    # values below from a `readout_noise` flag (RESET_NOISE_CHARGE = 900 e-
    # and UNCORRELATED_NOISE_CHARGE = 500 e- when enabled; the else-branch
    # values were not recorded). Here both are plain fields set by the caller.
    #: Reset noise in e-
    RESET_NOISE_CHARGE: float = struct.field(pytree_node=False)
    #: Uncorrelated noise in e-
    UNCORRELATED_NOISE_CHARGE: float = struct.field(pytree_node=False)

In [66]:
import optax

# One Adam transformation per label, each built with its own first argument
# (0.1 for 'eField', 0.01 for 'Ab').
transforms_by_label = {
    'eField': optax.adam(0.1),
    'Ab': optax.adam(0.01),
}
# Labels passed as the second positional argument of multi_transform.
# NOTE(review): a flat list only lines up with a two-leaf list of parameters;
# confirm this matches the structure of the parameters that will be optimised.
param_labels = ['eField', 'Ab']

opt = optax.multi_transform(transforms_by_label, param_labels)
print(opt)

GradientTransformationExtraArgs(init=<function multi_transform.<locals>.init_fn at 0x7f3e1e5db520>, update=<function multi_transform.<locals>.update_fn at 0x7f3e1e5db5b0>)


In [68]:
from transformers import FlaxBertForSequenceClassification

# Checkpoint identifier handed to from_pretrained (the recorded output shows
# config.json and flax_model.msgpack being fetched for it).
BERT_CHECKPOINT = 'bert-base-uncased'

model = FlaxBertForSequenceClassification.from_pretrained(BERT_CHECKPOINT)
# Top-level keys of the parameter dict, used as an example of a nested pytree.
model.params.keys()

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

flax_model.msgpack:   0%|          | 0.00/438M [00:00<?, ?B/s]

2024-03-29 13:54:33.762247: E external/xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_UNKNOWN: unknown error
CUDA backend failed to initialize: FAILED_PRECONDITION: No visible GPU devices. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Some weights of the model checkpoint at bert-base-uncased were not used when initializing FlaxBertForSequenceClassification: {('cls', 'predictions', 'transform', 'dense', 'bias'), ('cls', 'predictions', 'transform', 'LayerNorm', 'scale'), ('cls', 'predictions', 'transform', 'LayerNorm', 'bias'), ('cls', 'predictions', 'transform', 'dense', 'kernel'), ('cls', 'predictions', 'bias')}
- This IS expected if you are initializing FlaxBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxBertForSequenceClassification fro

dict_keys(['bert', 'classifier'])

In [71]:
# Display the pytree definition of the BERT parameters: the nesting of dict
# keys, with every leaf shown as `*` in the recorded output.
jax.tree_util.tree_structure(model.params)

PyTreeDef({'bert': {'embeddings': {'LayerNorm': {'bias': *, 'scale': *}, 'position_embeddings': {'embedding': *}, 'token_type_embeddings': {'embedding': *}, 'word_embeddings': {'embedding': *}}, 'encoder': {'layer': {'0': {'attention': {'output': {'LayerNorm': {'bias': *, 'scale': *}, 'dense': {'bias': *, 'kernel': *}}, 'self': {'key': {'bias': *, 'kernel': *}, 'query': {'bias': *, 'kernel': *}, 'value': {'bias': *, 'kernel': *}}}, 'intermediate': {'dense': {'bias': *, 'kernel': *}}, 'output': {'LayerNorm': {'bias': *, 'scale': *}, 'dense': {'bias': *, 'kernel': *}}}, '1': {'attention': {'output': {'LayerNorm': {'bias': *, 'scale': *}, 'dense': {'bias': *, 'kernel': *}}, 'self': {'key': {'bias': *, 'kernel': *}, 'query': {'bias': *, 'kernel': *}, 'value': {'bias': *, 'kernel': *}}}, 'intermediate': {'dense': {'bias': *, 'kernel': *}}, 'output': {'LayerNorm': {'bias': *, 'scale': *}, 'dense': {'bias': *, 'kernel': *}}}, '10': {'attention': {'output': {'LayerNorm': {'bias': *, 'scale': *