# Setup

In [7]:
from google.colab import files

In [6]:
!rm train.py

In [8]:
# Run 3 times to upload model.py, seqio_utils.py, train.py, train.tfrecord

files.upload()

Saving train.py to train.py


{'train.py': b'"""Bianry for training bytecodes classifier transformer."""\n\nimport functools\nimport logging\nfrom typing import Any, Callable, Sequence, Type\n\nimport chex\nfrom clu import metric_writers\nfrom clu import metrics\nfrom clu import periodic_actions\nimport flax\nimport flax.linen as nn\nfrom flax.training import train_state\nimport jax\nfrom jax.experimental import mesh_utils\nimport jax.numpy as jnp\nimport ml_collections\nimport numpy as np\nimport optax\nfrom orbax import checkpoint as orbax_checkpoint\n\nimport model\nimport seqio_utils\n\n\nLearningRateSchedule = Callable[[chex.Numeric], chex.Numeric]\n\n\nclass Config(ml_collections.ConfigDict):\n  """Configuration for training bytecodes classifier transformer."""\n\n  model_name: str = \'\'\n  seed: int = 42\n\n  num_train_steps: int = 1_000\n  initial_step: int = 1\n  log_loss_every_steps = 25\n  eval_every_steps = 250\n  checkpoint_every_steps = 1_000\n\n  learning_rate: float = 0.001\n  lr_warmup_steps: int 

In [1]:
!ls

model  model.py  __pycache__  sample_data  seqio_utils.py  train.py  train.tfrecord


In [5]:
!pip3 install seqio

[31mERROR: Could not find a version that satisfies the requirement xnlp (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for xnlp[0m[31m
[0m

In [1]:
import logging
import ml_collections
import jax

import model
import seqio_utils
import train

In [2]:
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
 TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
 TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
 TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
 TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

# Test Run

In [11]:
config = ml_collections.ConfigDict()
config.seed = 42
config.model_name = 'bytecode_transformer_functions'

config.num_train_steps = 100
config.initial_step = 1
config.learning_rate = 1e-3
config.lr_warmup_steps = 0
config.lr_decay_steps = 0

config.per_device_batch_size = 1
config.train_tfrecord_path = 'train.tfrecord'
config.max_vocab_size = 256 + 128
config.seqlen = 2048

config.log_loss_every_steps = 100
config.eval_every_steps = 100

config.num_layers = 1
config.num_heads = 4
config.embed_dim = 256
config.transformer_mlp_dim = 512
config.classifier_mlp_dim = 32
# config.conv_layers = 1

config.workdir = 'model'

In [8]:
logging.basicConfig(filename='train.log', level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', force=True)

In [12]:
train.TrainClassifier(config, allow_duplicate_tasks=True)

INFO:2025-04-08 20:27:13,217:jax._src.mesh_utils:82: Reordering mesh to physical ring order on single-tray TPU v2/v3.


In [13]:
!cat train.log

2025-04-08 20:26:44,547 - INFO - Training with config: classifier_mlp_dim: 32
embed_dim: 256
eval_every_steps: 10
initial_step: 1
learning_rate: 0.001
log_loss_every_steps: 10
lr_decay_steps: 0
lr_warmup_steps: 0
max_vocab_size: 384
model_name: bytecode_transformer_functions
num_heads: 4
num_layers: 1
num_train_steps: 10
per_device_batch_size: 1
seed: 42
seqlen: 2048
train_tfrecord_path: train.tfrecord
transformer_mlp_dim: 512
workdir: model

2025-04-08 20:26:44,549 - INFO - Device count: 8
2025-04-08 20:26:44,549 - INFO - Using global batch size: 8
2025-04-08 20:26:44,620 - INFO - Reordering mesh to physical ring order on single-tray TPU v2/v3.
2025-04-08 20:26:44,621 - INFO - Using seed: 42
2025-04-08 20:26:44,716 - INFO - Total parameters: 633665
2025-04-08 20:26:44,719 - INFO - [Hyperparameters] {'classifier_mlp_dim': 32, 'embed_dim': 256, 'eval_every_steps': 10, 'initial_step': 1, 'learning_rate': 0.001, 'log_loss_every_steps': 10, 'lr_decay_steps': 0, 'lr_warmup_steps': 0, 'max_v

# Configs for Experiments

## Function Classification

In [16]:
n_examples = 28_080

In [17]:
batch_size = 128
per_device_batch_size = batch_size // jax.local_device_count()
n_epochs = 16
train_steps = n_epochs * (n_examples // batch_size)
log_loss_every_steps = train_steps // n_epochs  # Once per epoch
eval_every_steps = train_steps #// n_epochs  # Once per epoch
checkpoint_every_steps = train_steps // n_epochs

In [14]:
!mkdir -p model

In [18]:
config = ml_collections.ConfigDict()
config.seed = 42
config.model_name = 'bytecode_transformer_functions'

config.num_train_steps = train_steps
config.initial_step = 1
config.learning_rate = 1e-3
config.lr_warmup_steps = 0
config.lr_decay_steps = 0

config.per_device_batch_size = per_device_batch_size
config.train_tfrecord_path = '/path/to/train.trfrecord'
config.test_tfrecord_path = '/path/to/test/test.tfrecord'
config.max_vocab_size = 256 + 128
config.seqlen = 2048

config.log_loss_every_steps = log_loss_every_steps
config.eval_every_steps = eval_every_steps

config.checkpoint_every_steps = checkpoint_every_steps
config.checkpoint_dir = '/path/to/ckpt_dir/'

config.num_layers = 1
config.num_heads = 4
config.embed_dim = 256
config.transformer_mlp_dim = 512
config.classifier_mlp_dim = 32

config.workdir = 'model'

## Script Classification

In [None]:
n_examples_16k = 150891

n_examples = n_examples_16k

In [None]:
batch_size = 8
per_device_batch_size = batch_size // jax.local_device_count()
n_epochs = 1
train_steps = n_epochs * (n_examples // batch_size)
log_loss_every_steps = train_steps // n_epochs # Once per epoch
eval_every_steps = train_steps #// n_epochs  # Once per epoch
checkpoint_every_steps = train_steps // n_epochs

In [None]:
config = ml_collections.ConfigDict()
config.seed = 42
config.model_name = 'bytecode_transformer_functions'

config.num_train_steps = train_steps
config.initial_step = 1
config.learning_rate = 1e-3
config.lr_warmup_steps = 0
config.lr_decay_steps = 1.5 * train_steps

config.per_device_batch_size = per_device_batch_size
config.train_tfrecord_path = '/path/to/train.trfrecord'
config.test_tfrecord_path = '/path/to/test/test.tfrecord'
config.max_vocab_size = 256 + 128
config.seqlen = 16_384

config.log_loss_every_steps = log_loss_every_steps
config.eval_every_steps = eval_every_steps

config.checkpoint_every_steps = checkpoint_every_steps
config.checkpoint_dir = '/path/to/ckpt_dir/'

config.num_layers = 1
config.num_heads = 4
config.embed_dim = 256
config.transformer_mlp_dim = 512
config.classifier_mlp_dim = 32
# Set to 1, 2
# config.conv_blocks = 1

config.workdir = 'model'