#### **Import Libraries**

In [None]:
import jax 
import jax.numpy as jnp 
from dataclasses import dataclass
from typing import Callable, Tuple, Dict
import optax 
import matplotlib.pyplot as plt 
from datasets import load_dataset
from transformers import FlaxAutoModelForMaskedLM, AutoTokenizer
from transformers import FlaxAutoModelForSequenceClassification, AutoConfig
from rfp.losses import Supervised_Loss, softmax_cross_entropy, Cluster_Loss
from rfp.train import Trainer  

#### **Load Tokenizer and Model**

In [None]:
# Hugging Face checkpoint shared by the tokenizer, config and model below.
model_checkpoint = "roberta-base"

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# Explicitly pin the length that padding='max_length' / truncation=True pad
# and truncate to when tokenizing.
tokenizer.model_max_length = 512
print(tokenizer.model_max_length)

# Sequence-classification model configured with a 2-label output head.
config = AutoConfig.from_pretrained(model_checkpoint, num_labels=2)
model = FlaxAutoModelForSequenceClassification.from_pretrained(
    model_checkpoint,
    config=config,
)

#### **Tokenize Function**

In [None]:
def tokenizer_function(batch, text_field="text"):
    """Tokenize a batch of examples for ``Dataset.map(..., batched=True)``.

    Uses the module-level ``tokenizer``. Every text is padded and truncated
    to ``tokenizer.model_max_length`` tokens.

    Args:
        batch: Mapping of column name to a list of values, as handed over by
            ``datasets.Dataset.map`` in batched mode.
        text_field: Name of the column that holds the raw text. Defaults to
            ``"text"``, so existing ``map(tokenizer_function, ...)`` calls are
            unaffected. Other columns can be selected via ``fn_kwargs``.

    Returns:
        Dict containing whichever of ``"input_ids"`` and ``"attention_mask"``
        the tokenizer produced, one row per example. Existing columns such as
        ``"label"`` are not returned here; ``Dataset.map`` keeps them next to
        the new columns.
    """
    encoded = tokenizer(
        batch[text_field],
        padding="max_length",
        truncation=True,
        # NumPy rather than JAX: Dataset.map serialises the result to Arrow
        # anyway, so building device arrays here (and then copying them again
        # with jnp.array) only added a pointless host<->device round trip.
        return_tensors="np",
    )
    # Keep only the model inputs. The tokenizer never emits a "label" key,
    # so filtering on it was dead code.
    model_inputs = ("input_ids", "attention_mask")
    return {key: value for key, value in encoded.items() if key in model_inputs}


#### **Load and Tokenize Dataset**

In [None]:
# Take the "train" split and add the tokenizer's output columns to it.
original_dataset = (
    load_dataset("ppower1/instrument")["train"]
    .map(tokenizer_function, batched=True)
)

In [None]:
# Number of examples in the tokenized training split (shown as cell output).
original_dataset.num_rows

#### **Define Forward Pass**

In [None]:
def rfp(params, tokens):
    """Run the classifier's forward pass with an explicit parameter pytree.

    Args:
        params: Parameters to evaluate ``model`` with (e.g. ``model.params``).
        tokens: Mapping providing ``'input_ids'`` and ``'attention_mask'``.

    Returns:
        ``(logits, 0.0)``. The constant second element is presumably an
        auxiliary/regularisation term consumed by ``Supervised_Loss``;
        TODO confirm.
    """
    outputs = model(
        input_ids=tokens['input_ids'],
        attention_mask=tokens['attention_mask'],
        params=params,
    )
    return outputs.logits, 0.0

In [None]:
# Pair the softmax cross-entropy criterion with the `rfp` forward pass.
# NOTE(review): how the loss is reduced and weighted is defined inside
# rfp.losses.Supervised_Loss, which is not visible here.
supervised_loss = Supervised_Loss(
    softmax_cross_entropy,
    rfp,
)

In [None]:
# Evaluate the loss once, at the pretrained parameters, over the whole dataset.
# NOTE(review): the full Dataset object is passed as `tokens`, so `rfp` runs
# every row (each padded to 512 tokens) through the model in a single forward
# pass. This is likely to exhaust accelerator memory on larger datasets;
# consider batching.
# The float32 labels and the all-ones (N, 1) array are presumably the targets
# and per-example weights expected by Supervised_Loss; TODO confirm.
supervised_loss(
    model.params,
    original_dataset,
    jnp.array(original_dataset['label'], dtype=jnp.float32),
    jnp.ones((len(original_dataset), 1), dtype=jnp.float32),
)

In [None]:
# Show the tokenized dataset object as this cell's output.
original_dataset