In [3]:
import torch
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    AutoConfig,
    DataCollatorForLanguageModeling,
    AutoModelForMaskedLM,
    AdamW
)
from datasets import load_dataset

In [4]:
# Model checkpoint: loaded as the teacher and also used as the base config for the student.
checkpoint = "albert/albert-base-v2"
# Tokenizer checkpoint, kept as a separate name from the model checkpoint.
tokenizer_checkpoint = "albert/albert-base-v2"
# Dataset identifier passed to `load_dataset` (with the "en" configuration) in the data cell.
dataset_name = "xu-song/cc100-samples"

In [27]:
# Load the full "train" split of the English configuration and the matching tokenizer.
dataset = load_dataset(dataset_name, "en", split="train[:100%]")
tokenizer = AutoTokenizer.from_pretrained(tokenizer_checkpoint)

def tokenize_function(example):
    """Tokenize the "text" field of a batch, truncating and padding to 50 tokens.

    Relies on the notebook-level `tokenizer`. Every example is padded to
    `max_length`, so all sequences come out with the same length.
    """
    return tokenizer(
        example["text"],
        truncation=True,
        padding="max_length",
        max_length=50
    )

# Tokenize and split dataset
# The raw "text" column is dropped so only tokenizer outputs remain.
dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
# NOTE(review): no seed is passed, so the 80/20 split presumably differs between runs - confirm.
dataset = dataset.train_test_split(test_size=0.2)
train_data = dataset["train"]

# Collator used by every DataLoader in this notebook.
# NOTE(review): with mlm=True the collator is expected to mask ~15% of input_ids, so the
# distillation loops would see masked inputs for both teacher and student - confirm intended.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15
)

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [29]:
# Show the tokenized train/test splits (features and row counts).
display(dataset)

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 8000
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2000
    })
})

In [30]:
# Load the teacher; its hidden states serve as the distillation targets.
teacher_model = AutoModelForMaskedLM.from_pretrained(checkpoint)
teacher_model.config.output_hidden_states = True
# The teacher is never trained: switch to evaluation mode and freeze every parameter.
teacher_model.eval().requires_grad_(False)

Some weights of the model checkpoint at albert/albert-base-v2 were not used when initializing AlbertForMaskedLM: ['albert.pooler.bias', 'albert.pooler.weight']
- This IS expected if you are initializing AlbertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing AlbertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [17]:
# Define a small student configuration (BERT-like) for distillation.
# The teacher checkpoint's config is loaded and the size-related fields are overridden.
config = AutoConfig.from_pretrained(
    pretrained_model_name_or_path=checkpoint,
    hidden_size=128,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=384,
    hidden_act="gelu",
)
# Built from the config only (`from_config`), i.e. not from the checkpoint's weights.
student_model = AutoModelForMaskedLM.from_config(config)
student_model.config.output_hidden_states = True

In [31]:
# Shuffled batches of 8 tokenized examples, collated by the MLM data collator.
train_dataloader = DataLoader(train_data, batch_size=8, shuffle=True, collate_fn=data_collator)

In [None]:
def trainer(dataloader, teacher_model, student_model, epochs=10, device=None):
    """Distil the teacher's hidden states into the student with an MSE objective.

    Each student layer is projected to twice the teacher's hidden size and
    regressed onto the concatenation of two teacher layers: one picked by a
    uniform layer mapping and one picked by a "last layers" mapping.

    Args:
        dataloader: Iterable of batches; each batch is a mapping with
            "input_ids" and "attention_mask" tensors.
        teacher_model: Model whose outputs expose `hidden_states` and whose
            `config.hidden_size` gives the teacher width. Run without gradients.
        student_model: Model to train; must expose the same two attributes.
        epochs: Number of passes over `dataloader`.
        device: Device to train on. Defaults to CUDA when available, else CPU.

    Returns:
        Average distillation loss of the last epoch (NaN if nothing was trained).
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    projection = torch.nn.Linear(student_model.config.hidden_size,
                                 2 * teacher_model.config.hidden_size).to(device)

    # Jointly optimise student model and projection - simplification compared to paper
    optimizer = torch.optim.AdamW(list(student_model.parameters()) + list(projection.parameters()), lr=5e-5)

    student_model.to(device)
    teacher_model.to(device)

    # Defined up front so that `epochs=0` returns NaN instead of raising.
    avg_loss = float('nan')

    for epoch in range(epochs):
        student_model.train()
        running_loss = 0.0
        num_batches = 0

        # Iterate over the loader that was passed in, not a notebook-level global.
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)

            # Forward pass through teacher (with no gradients).
            with torch.no_grad():
                teacher_outputs = teacher_model(input_ids=input_ids,
                                                attention_mask=attention_mask,
                                                output_hidden_states=True)
            student_outputs = student_model(input_ids=input_ids,
                                            attention_mask=attention_mask,
                                            output_hidden_states=True)

            # Extract hidden states - skip embedding
            teacher_hidden = teacher_outputs.hidden_states[1:]
            student_hidden = student_outputs.hidden_states[1:]

            # For each student layer, project the full hidden state
            Hs = torch.stack(list(student_hidden))  # (num_student_layers, batch, seq_len, hidden_size)
            Hs_proj = projection(Hs)  # (num_student_layers, batch, seq_len, 2 * teacher_hidden_size)

            # Map teacher hidden states to student layers using a Uniform+Last strategy
            num_student_layers = len(student_hidden)
            num_teacher_layers = len(teacher_hidden)
            Ht = []
            for i in range(num_student_layers):
                # Uniform mapping: spread the student layers evenly over the teacher's depth.
                idx_uniform = int(i * num_teacher_layers / num_student_layers)
                # Last mapping: align the student with the teacher's final layers.
                idx_last = i + num_teacher_layers - num_student_layers

                # Concatenate along the hidden dimension -> (batch, seq_len, 2 * teacher_hidden_size)
                Ht.append(torch.cat([teacher_hidden[idx_uniform], teacher_hidden[idx_last]], dim=-1))

            # Stack teacher states to match student projection
            Ht = torch.stack(Ht)  # (num_student_layers, batch, seq_len, 2 * teacher_hidden_size)

            loss = torch.nn.functional.mse_loss(Hs_proj, Ht)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Average over the batches actually processed; NaN for an empty loader.
        avg_loss = running_loss / num_batches if num_batches else float('nan')
        print(f"Epoch {epoch+1} | Avg Distillation Loss: {avg_loss:.4f}")

    return avg_loss

In [39]:
# Distil the teacher into the small hand-defined student for 10 epochs; keep the last epoch's average loss.
final_loss = trainer(train_dataloader, teacher_model, student_model, epochs=10)



Epoch 1 | Avg Distillation Loss: 0.6605
Epoch 2 | Avg Distillation Loss: 0.4494
Epoch 3 | Avg Distillation Loss: 0.4074
Epoch 4 | Avg Distillation Loss: 0.3861
Epoch 5 | Avg Distillation Loss: 0.3663
Epoch 6 | Avg Distillation Loss: 0.3512
Epoch 7 | Avg Distillation Loss: 0.3439
Epoch 8 | Avg Distillation Loss: 0.3336
Epoch 9 | Avg Distillation Loss: 0.3249
Epoch 10 | Avg Distillation Loss: 0.3242


In [40]:
def get_search_space():
    """Return the architecture search space as {hyperparameter name: list of choices}."""
    return dict(
        num_hidden_layers=[3, 4, 6, 10, 12],
        num_attention_heads=[2, 3, 4, 6, 12],
        hidden_size=[384, 768],
        intermediate_size=[384, 512, 576, 768, 1024, 1536, 2048, 3072],
        hidden_act=['gelu', 'relu', 'silu'],
    )

# A state is a list of choice indices, ordered like `state_keys`.
search_space = get_search_space()
state_keys = list(search_space)

In [None]:
from transformers import AlbertConfig, AlbertForMaskedLM
import numpy as np
import time

# Utils
def construct_student_model_from_config(config):
    """Build an `AlbertForMaskedLM` from a dict of `AlbertConfig` keyword arguments."""
    return AlbertForMaskedLM(AlbertConfig(**config))

def get_latency(config_or_model_config, batch_size=1, seq_length=50, num_warmup=5, num_runs=20):
    """Measure the average CPU latency (in seconds) of one forward pass.

    A fresh model is built from the given configuration, so only the
    architecture (not any trained weights) influences the measurement. The
    benchmark runs on CPU because that is the expected deployment target and
    it tends to give more stable timings.

    Args:
        config_or_model_config: Either a dict of config keyword arguments
            (built via `construct_student_model_from_config`) or a config
            object accepted by `AlbertForMaskedLM`.
        batch_size: Batch size of the dummy input.
        seq_length: Sequence length of the dummy input.
        num_warmup: Untimed forward passes executed before measuring.
        num_runs: Timed forward passes that are averaged.

    Returns:
        Mean seconds per forward pass over `num_runs` runs.

    Raises:
        ValueError: If `num_runs` is not positive.
    """
    if num_runs <= 0:
        raise ValueError("num_runs must be a positive integer")

    # construct model
    if isinstance(config_or_model_config, dict):
        model = construct_student_model_from_config(config_or_model_config)
    else:
        model = AlbertForMaskedLM(config_or_model_config)
    model.eval()
    model.to('cpu')

    # dummy input: token ids are drawn from [0, 100), so the vocabulary must have at least 100 entries
    dummy_input = torch.randint(0, 100, (batch_size, seq_length))
    attention_mask = torch.ones_like(dummy_input)

    with torch.no_grad():
        # warmup
        for _ in range(num_warmup):
            model(dummy_input, attention_mask=attention_mask)

        # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
        start = time.perf_counter()
        for _ in range(num_runs):
            model(dummy_input, attention_mask=attention_mask)
        elapsed = time.perf_counter() - start

    return elapsed / num_runs

def calculate_reward(loss, latency, teacher_latency, alpha=-0.06, beta=0.6**6):
    """
    Reward function from the paper, trading distillation loss against latency:
    reward = (1 - L_HS) * (lat(S) / (beta * lat(T))) ^ alpha
    """
    latency_target = beta * teacher_latency
    latency_factor = (latency / latency_target) ** alpha
    return (1 - loss) * latency_factor

## LSTM Controller

- **Architecture**: Embedding layers for each hyperparameter choice, followed by an LSTM, and a linear layer to predict a scalar reward.
- **Input**: A sequence of three states (previous best, global best, current state), each represented as concatenated embeddings of hyperparameter indices.
- **Output**: Predicted reward for the current state.

In [50]:
import torch.nn as nn

# LSTM Controller
class Controller(nn.Module):
    """LSTM reward predictor for architecture states.

    Every hyperparameter has its own embedding table. A state (one choice index
    per hyperparameter) is embedded by concatenating the per-hyperparameter
    embeddings. The controller reads the sequence
    (previous best, global best, current state) with an LSTM and maps the final
    hidden state to a scalar reward prediction.
    """

    def __init__(self, search_space, embedding_dim=32, lstm_hidden=32):
        """
        Args:
            search_space: Mapping of hyperparameter name -> list of choices.
            embedding_dim: Embedding width used for every hyperparameter.
            lstm_hidden: Hidden size of the LSTM.
        """
        super().__init__()
        self.embeddings = nn.ModuleDict()
        for param, choices in search_space.items():
            self.embeddings[param] = nn.Embedding(len(choices), embedding_dim)
        input_size = embedding_dim * len(search_space)
        self.lstm = nn.LSTM(input_size, lstm_hidden, batch_first=True)
        self.fc = nn.Linear(lstm_hidden, 1)
        self.state_keys = list(search_space.keys())

    def _embed_state(self, state):
        """Embed one state (sequence or tensor of choice indices) as a (1, input_size) tensor."""
        # Follow the module's own device instead of assuming CUDA.
        device = self.fc.weight.device
        indices = torch.as_tensor(state, dtype=torch.long, device=device).reshape(-1)
        if indices.numel() != len(self.state_keys):
            raise ValueError(
                f"state has {indices.numel()} entries, expected {len(self.state_keys)}"
            )
        pieces = [self.embeddings[key](indices[pos:pos + 1])
                  for pos, key in enumerate(self.state_keys)]
        return torch.cat(pieces, dim=-1)

    def forward(self, prev_best, global_best, current_state):
        """Predict reward for current_state given prev_best and global_best.

        Each argument holds one choice index per hyperparameter, in
        `state_keys` order, as a list or a 1-D tensor.

        Returns:
            A 0-dimensional tensor with the predicted reward.

        Raises:
            ValueError: If a state does not have one index per hyperparameter.
        """
        steps = [self._embed_state(state) for state in (prev_best, global_best, current_state)]
        # Batch-first sequence of length 3: (1, 3, input_size).
        sequence = torch.stack(steps, dim=1)
        _, (h_n, _) = self.lstm(sequence)
        reward_pred = self.fc(h_n[-1])
        return reward_pred.squeeze()

## Mini-KD Process
- **Proxy Set**: 30% of the training data, shuffled and selected once before the NAS loop
- **Training**: 4 epochs of KD, with hidden state mapping + projection
- **Output**: Average distillation loss of final epoch

In [51]:
def mini_kd_trainer(proxy_data, teacher_model, student_model, epochs=4):
    """Run a short hidden-state distillation on the proxy set.

    The student's layers are projected to twice the teacher's hidden size and
    matched (MSE) against the concatenation of a uniformly mapped teacher layer
    and a "last layers" mapped teacher layer. Returns the average distillation
    loss of the final epoch.
    """
    loader = DataLoader(proxy_data, batch_size=32, shuffle=True, collate_fn=data_collator)
    teacher_width = teacher_model.config.hidden_size
    projector = nn.Linear(student_model.config.hidden_size, 2 * teacher_width).to('cuda')

    # Student and projection are optimised jointly.
    trainable = [*student_model.parameters(), *projector.parameters()]
    optimizer = torch.optim.AdamW(trainable, lr=5e-5)
    student_model.to('cuda')
    teacher_model.to('cuda')

    for epoch_idx in range(epochs):
        student_model.train()
        loss_sum = 0.0
        for batch in loader:
            model_inputs = dict(
                input_ids=batch["input_ids"].to('cuda'),
                attention_mask=batch["attention_mask"].to('cuda'),
                output_hidden_states=True,
            )
            # Teacher runs without gradients; index 0 (embedding output) is skipped for both models.
            with torch.no_grad():
                teacher_states = teacher_model(**model_inputs).hidden_states[1:]
            student_states = student_model(**model_inputs).hidden_states[1:]

            projected = projector(torch.stack(list(student_states)))

            n_student = len(student_states)
            n_teacher = len(teacher_states)
            # Uniform + Last mapping: one target per student layer.
            targets = torch.stack([
                torch.cat(
                    [
                        teacher_states[int(layer * n_teacher / n_student)],
                        teacher_states[layer + n_teacher - n_student],
                    ],
                    dim=-1,
                )
                for layer in range(n_student)
            ])

            loss = nn.functional.mse_loss(projected, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_sum += loss.item()

        avg_loss = loss_sum / len(loader)
        print(f"Epoch {epoch_idx+1}/{epochs} | Avg Distillation Loss: {avg_loss:.4f}")
    return avg_loss

In [47]:
def train_controller(controller, optimizer, training_data, prev_best, global_best, epochs=10):
    """Train the controller to predict rewards using MSE loss.

    Args:
        controller: Module called as `controller(prev_best, global_best, state)`
            that returns a scalar reward prediction.
        optimizer: Optimizer over the controller's parameters.
        training_data: Iterable of `(state, reward)` pairs, where a state is a
            sequence of choice indices.
        prev_best: State of the previous episode's best candidate.
        global_best: State of the best candidate seen so far.
        epochs: Number of passes over `training_data`.

    Only the final epoch's average loss is printed. Nothing is trained or
    printed when `training_data` is empty or `epochs` is not positive.
    """
    controller.train()

    samples = list(training_data)
    if not samples or epochs <= 0:
        return

    # Keep the inputs on the controller's own device rather than assuming CUDA.
    first_param = next(iter(controller.parameters()), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # These tensors do not change between steps, so build them once.
    prev_best_tensor = torch.as_tensor(prev_best, device=device)
    global_best_tensor = torch.as_tensor(global_best, device=device)
    tensor_samples = [
        (torch.as_tensor(state, device=device),
         torch.tensor([reward], dtype=torch.float32, device=device))
        for state, reward in samples
    ]

    for epoch in range(epochs):
        running_loss = 0.0
        for state_tensor, reward_tensor in tensor_samples:
            optimizer.zero_grad()
            reward_pred = controller(prev_best_tensor, global_best_tensor, state_tensor)
            # Match the (1,) target shape to avoid a broadcasting warning from mse_loss.
            loss = nn.functional.mse_loss(reward_pred.reshape(-1), reward_tensor)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        avg_loss = running_loss / len(tensor_samples)
    print(f"Controller Epoch {epochs} | Avg Loss: {avg_loss:.4f}")

In [54]:
# Hyperparams
M = 10  # Number of episodes. Complete run = 15
N = 5  # Number of candidates per episode. Complete run = 20
epsilon = 1.0  # Initial fraction of candidates sampled at random (exploration ratio)
epsilon_min = 0.05  # Floor for the exploration ratio
epsilon_decay = 0.05  # Amount subtracted from epsilon after every episode
pool_size = 100  # Size of state pool for controller selection

# Initialise controller
controller = Controller(search_space).to('cuda')
controller_optimizer = torch.optim.RMSprop(controller.parameters(), lr=1e-4)

# Initialise best states
# A state is one random choice index per hyperparameter, ordered like `state_keys`.
# Both dicts start from the same list object and a reward of -inf, so the first
# evaluated candidate replaces them.
random_state = [np.random.randint(len(search_space[key])) for key in state_keys]
global_best = {'state': random_state, 'reward': -float('inf')}
previous_best = {'state': random_state, 'reward': -float('inf')}

# Proxy dataset - 30% of original for shorter KD
# NOTE(review): shuffle() is called without a seed, so the proxy subset presumably changes between runs.
proxy_size = int(0.3 * len(train_data))
proxy_data = train_data.shuffle().select(range(proxy_size))

# Reference latency: a model built from the teacher's config, timed by get_latency.
teacher_latency = get_latency(teacher_model.config)

# Accumulates one {'state', 'config', 'reward'} record per evaluated candidate across all episodes.
evaluated_states = []




## NAS Loop
- **Episodes**: M episodes, with N candidates per episode. Controller trained at the end of every episode.
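- **Candidate Selection**: The controller scores a randomly sampled pool (e.g. 100 states) and the states with the highest predicted reward are selected; random states fill the remaining slots.
- **Exploration/Exploitation**: Controlled by ε, which starts at 1.0 and decreases by 0.05 per episode down to a floor of 0.05. Later episodes take more candidates from the controller as its predictions improve with training (see the paper for details).
- **Evaluation**: Mini-KD provides the distillation loss and a CPU forward-pass benchmark provides the latency; both feed the reward calculation.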
- **Controller Training**: Uses MSE loss between predicted and actual rewards, incorporating all candidate states plus global/previous best states

TODO:
- Cache architecture rewards, so if duplicate sample, then can just fetch reward and avoid KD again

In [None]:
# NAS Episodes
# Each episode: pick N candidate states, evaluate each with Mini-KD + latency,
# update the best states, train the controller on the results, then decay epsilon.
for episode in range(M):
    print(f"\nEpisode {episode+1}/{M}, Exploration Ratio: {epsilon:.2f}")
    # Split the N candidates between random exploration and controller-guided picks.
    # int() truncates, so fractional shares go to the controller.
    num_random = int(epsilon * N)
    num_controller = N - num_random
    candidate_states = []

    # Exploitation: Controller predicts high reward states
    if num_controller > 0:
        # Sample a random pool and keep the states with the highest predicted reward.
        pool_states = [[np.random.randint(len(search_space[key])) for key in state_keys]
                      for _ in range(pool_size)]
        predicted_rewards = []
        for state in pool_states:
            with torch.no_grad():
                reward_pred = controller(previous_best['state'], global_best['state'], state)
            predicted_rewards.append(reward_pred.item())
        # argsort is ascending, so the last `num_controller` indices are the top predictions.
        indices = np.argsort(predicted_rewards)[-num_controller:]
        for idx in indices:
            candidate_states.append(pool_states[idx])

    # Exploration: Random states
    for _ in range(num_random):
        state = [np.random.randint(len(search_space[key])) for key in state_keys]
        candidate_states.append(state)

    # Evaluate candidates
    episode_rewards = []
    for state_idx, state in enumerate(candidate_states):
        # Translate choice indices into concrete hyperparameter values.
        config = {key: search_space[key][idx] for key, idx in zip(state_keys, state)}
        print(f"\n  Candidate {state_idx+1}/{N} Configuration:")
        for key, value in config.items():
            print(f"    {key}: {value}")
        # A fresh student is built for every candidate; note that this rebinds the
        # notebook-level `student_model` name.
        student_model = construct_student_model_from_config(config)
        loss = mini_kd_trainer(proxy_data, teacher_model, student_model)
        # Latency is measured on a separately constructed model of the same configuration.
        latency = get_latency(config)
        reward = calculate_reward(loss, latency, teacher_latency)
        episode_rewards.append(reward)
        evaluated_states.append({'state': state, 'config': config, 'reward': reward})
        print(f"    Mini-KD Loss: {loss:.4f}, Latency: {latency:.4f}, Reward: {reward:.4f}")

    # Update best states
    # `previous_best` always becomes this episode's best; `global_best` only on improvement.
    best_idx = np.argmax(episode_rewards)
    episode_best_state = candidate_states[best_idx]
    episode_best_reward = episode_rewards[best_idx]
    previous_best = {'state': episode_best_state, 'reward': episode_best_reward}
    if episode_best_reward > global_best['reward']:
        global_best = {'state': episode_best_state, 'reward': episode_best_reward}

    # Train controller
    # The best states are appended on top of the episode's candidates; since `previous_best`
    # was just taken from those candidates, that state appears in the training data twice.
    training_data = list(zip(candidate_states, episode_rewards))
    training_data.append((global_best['state'], global_best['reward']))
    training_data.append((previous_best['state'], previous_best['reward']))
    train_controller(controller, controller_optimizer, training_data,
                    previous_best['state'], global_best['state'])

    # Decay exploration ratio
    epsilon = max(epsilon_min, epsilon - epsilon_decay)


Episode 1/10, Exploration Ratio: 1.00

  Candidate 1/5 Configuration:
    num_hidden_layers: 4
    num_attention_heads: 3
    hidden_size: 768
    intermediate_size: 1024
    hidden_act: gelu
Epoch 1/4 | Avg Distillation Loss: 0.8565
Epoch 2/4 | Avg Distillation Loss: 0.7207
Epoch 3/4 | Avg Distillation Loss: 0.5882
Epoch 4/4 | Avg Distillation Loss: 0.5430
    Mini-KD Loss: 0.5430, Latency: 0.0137, Reward: 0.4231

  Candidate 2/5 Configuration:
    num_hidden_layers: 6
    num_attention_heads: 2
    hidden_size: 768
    intermediate_size: 576
    hidden_act: gelu
Epoch 1/4 | Avg Distillation Loss: 0.8406
Epoch 2/4 | Avg Distillation Loss: 0.6801
Epoch 3/4 | Avg Distillation Loss: 0.5752
Epoch 4/4 | Avg Distillation Loss: 0.5433
    Mini-KD Loss: 0.5433, Latency: 0.0161, Reward: 0.4188

  Candidate 3/5 Configuration:
    num_hidden_layers: 6
    num_attention_heads: 3
    hidden_size: 768
    intermediate_size: 384
    hidden_act: gelu
Epoch 1/4 | Avg Distillation Loss: 0.8408
Epoch 2

  loss = nn.functional.mse_loss(reward_pred, reward_tensor)


Controller Epoch 6/10 | Avg Loss: 0.0001
Controller Epoch 7/10 | Avg Loss: 0.0001
Controller Epoch 8/10 | Avg Loss: 0.0000
Controller Epoch 9/10 | Avg Loss: 0.0000
Controller Epoch 10/10 | Avg Loss: 0.0000

Episode 2/10, Exploration Ratio: 0.95

  Candidate 1/5 Configuration:
    num_hidden_layers: 6
    num_attention_heads: 12
    hidden_size: 384
    intermediate_size: 768
    hidden_act: gelu
Epoch 1/4 | Avg Distillation Loss: 0.9244
Epoch 2/4 | Avg Distillation Loss: 0.7673
Epoch 3/4 | Avg Distillation Loss: 0.7477
Epoch 4/4 | Avg Distillation Loss: 0.6742
    Mini-KD Loss: 0.6742, Latency: 0.0115, Reward: 0.3048

  Candidate 2/5 Configuration:
    num_hidden_layers: 3
    num_attention_heads: 12
    hidden_size: 768
    intermediate_size: 1536
    hidden_act: relu
Epoch 1/4 | Avg Distillation Loss: 0.8372
Epoch 2/4 | Avg Distillation Loss: 0.7284
Epoch 3/4 | Avg Distillation Loss: 0.5953
Epoch 4/4 | Avg Distillation Loss: 0.5195
    Mini-KD Loss: 0.5195, Latency: 0.0120, Reward: 0

In [59]:
# Select top 3 architectures
# The sort is in place, so `evaluated_states` itself ends up ordered by descending reward.
# NOTE(review): states are not de-duplicated, so the same architecture could presumably appear more than once.
evaluated_states.sort(key=lambda x: x['reward'], reverse=True)
top_3 = evaluated_states[:3]
print("\nTop 3 Architectures:")
for i, arch in enumerate(top_3, 1):
    print(f"{i}. State: {arch['state']}, Config: {arch['config']}, Reward: {arch['reward']:.4f}")


Top 3 Architectures:
1. State: [1, 0, 1, 6, 1], Config: {'num_hidden_layers': 4, 'num_attention_heads': 2, 'hidden_size': 768, 'intermediate_size': 2048, 'hidden_act': 'relu'}, Reward: 0.4509
2. State: [0, 3, 1, 5, 1], Config: {'num_hidden_layers': 3, 'num_attention_heads': 6, 'hidden_size': 768, 'intermediate_size': 1536, 'hidden_act': 'relu'}, Reward: 0.4488
3. State: [0, 4, 1, 5, 1], Config: {'num_hidden_layers': 3, 'num_attention_heads': 12, 'hidden_size': 768, 'intermediate_size': 1536, 'hidden_act': 'relu'}, Reward: 0.4484
