#### **Import Libraries**

In [1]:
import jax 
import jax.numpy as jnp 
from transformers import AutoTokenizer, FlaxAutoModelForSequenceClassification
from datasets import load_dataset
import optax 

#### **Plotting**

In [2]:
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib import rcParams
from matplotlib import font_manager

# Configure matplotlib parameters
rcParams['image.interpolation'] = 'nearest'
rcParams['image.cmap'] = 'viridis'
rcParams['axes.grid'] = False

# Enable inline plotting for Jupyter notebooks and use SVG format for higher quality
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

# Set the style for plots
plt.style.use('seaborn-v0_8-dark-palette')

# Define the location of the custom font files
font_location = './../../styles/Newsreader'

# Find all font files in the specified location
font_files = font_manager.findSystemFonts(fontpaths=font_location)

# Print the font location and the first font file found for verification
print(f"Font location: {font_location}")
print(f"First font file: {font_files[0]}")

# Add all the found font files to the font manager
for font_file in font_files:
    font_manager.fontManager.addfont(font_file)

# Set the default font family to the custom font
plt.rcParams["font.family"] = "Newsreader"


Font location: ./../../styles/Newsreader
First font file: /home/ubuntu/llmftax/styles/Newsreader/static/Newsreader_14pt/Newsreader_14pt-ExtraLight.ttf


#### **Parameters**

In [3]:
model_id = 'roberta-base'
epochs = 10
batch_size = 32 

#### **Tokenizer**

In [4]:
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.model_max_length = 512

def tokenizer_function(example):
    return tokenizer(example['text'], truncation=True, padding='max_length')



#### **Model**

In [5]:
model = FlaxAutoModelForSequenceClassification.from_pretrained(model_id,
                                                           num_labels=2)

2024-05-16 16:21:15.725030: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
Some weights of the model checkpoint at roberta-base were not used when initializing FlaxRobertaForSequenceClassification: {('lm_head', 'dense', 'kernel'), ('lm_head', 'dense', 'bias'), ('lm_head', 'layer_norm', 'bias'), ('lm_head', 'bias'), ('lm_head', 'layer_norm', 'scale')}
- This IS expected if you are initializing FlaxRobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxRobertaForSeq

In [6]:
original_dataset = load_dataset("ppower1/instrument")['train']
device = str(jax.devices()[0])
original_dataset = original_dataset.with_format("jax", device=device)

def check_prefix(example):
    example['type_indicator'] = 1 if example['text'].startswith('Yes') else 0
    return example
original_dataset = original_dataset.map(check_prefix)
dataset = original_dataset.train_test_split(test_size=0.5, seed=42)
tokenized_dataset = dataset.map(tokenizer_function, batched=True)
tokenized_dataset = tokenized_dataset.remove_columns(['text', 'treated text', 'control text', 'raw_label'])



In [7]:
def loss_fn(params, batch):
    # Perform a forward pass through the model to get the logits
    logits = model(params=params, input_ids=batch['input_ids'], attention_mask=batch['attention_mask']).logits
    
    # Get the true labels from the batch
    labels = batch['label']
    
    # Compute the two-class cross-entropy loss
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
    
    # Return the average loss across the batch
    return jnp.mean(loss)

In [8]:
opt = optax.sgd(learning_rate=1e-4)
params = model.params
opt_state = opt.init(params)
for epoch in range(epochs):
    for batch in tokenized_dataset['train'].iter(batch_size=batch_size):
        loss, grads = jax.value_and_grad(loss_fn)(params, batch)
        updates, opt_state = opt.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        print(loss)

0.80822366


2024-05-16 16:21:46.581820: W external/tsl/tsl/framework/bfc_allocator.cc:482] Allocator (GPU_0_bfc) ran out of memory trying to allocate 192.00MiB (rounded to 201326592)requested by op 
2024-05-16 16:21:46.582628: W external/tsl/tsl/framework/bfc_allocator.cc:494] ****************************************************************************************************
E0516 16:21:46.582682    9021 pjrt_stream_executor_client.cc:2826] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 201326592 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:  192.00MiB
              constant allocation:         0B
        maybe_live_out allocation:  192.00MiB
     preallocated temp allocation:         0B
                 total allocation:  384.00MiB
              total fragmentation:         0B (0.00%)
Peak buffers:
	Buffer 1:
		Size: 192.00MiB
		Entry Parameter Subshape: f32[32,512,3072]

	Buffer 2:
		Size: 192.00MiB
	

ValueError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 201326592 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:  192.00MiB
              constant allocation:         0B
        maybe_live_out allocation:  192.00MiB
     preallocated temp allocation:         0B
                 total allocation:  384.00MiB
              total fragmentation:         0B (0.00%)
Peak buffers:
	Buffer 1:
		Size: 192.00MiB
		Entry Parameter Subshape: f32[32,512,3072]
		==========================

	Buffer 2:
		Size: 192.00MiB
		XLA Label: fusion
		Shape: f32[32,512,3072]
		==========================



In [None]:
loss