# Compute forgetting curves and model weight composition

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm

# Global matplotlib styling for every figure in this notebook:
#   - fonttype 42 embeds TrueType fonts in PDF/PS output (avoids Type 3 fonts)
#   - base font size 18, with smaller (14) labels on the axis ticks
plt.rcParams.update({
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.size': 18,
    'xtick.labelsize': 14,
    'ytick.labelsize': 14,
})

In [None]:
# cosine learning rate schedule
def cosine_lr_schedule(step, total_steps, warmup_steps=700, max_lr=6e-4, min_lr=6e-5):
    """Return the learning rate at ``step`` for linear warmup plus cosine decay.

    Args:
        step: Gradient step to evaluate (0-based).
        total_steps: Step at which the cosine decay reaches ``min_lr``.
        warmup_steps: Number of steps over which the rate ramps linearly
            from 0 up to ``max_lr``.
        max_lr: Peak learning rate, reached when warmup ends.
        min_lr: Final learning rate; also returned for every
            ``step >= total_steps``.

    Returns:
        The learning rate at ``step``.
    """
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    if step >= total_steps:
        # Past the end of the schedule the cosine would wrap around and the
        # rate would rise again; hold it at the floor instead. This guard also
        # covers total_steps == warmup_steps, which would otherwise divide by
        # zero below.
        return min_lr
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + np.cos(progress * np.pi))

# Evaluate the cosine schedule at every step of a run and plot the curve.
steps = 47683
lrs = [cosine_lr_schedule(s, total_steps=steps) for s in range(steps)]
plt.plot(
    lrs,
    color='b',
    linestyle='-',
)

In [None]:
# linear learning rate schedule
def linear_lr_schedule(step, total_steps, warmup_steps=700, max_lr=6e-4, min_lr=6e-5):
    """Return the learning rate at ``step`` for linear warmup plus linear decay.

    Args:
        step: Gradient step to evaluate (0-based).
        total_steps: Step at which the linear decay reaches ``min_lr``.
        warmup_steps: Number of steps over which the rate ramps linearly
            from 0 up to ``max_lr``.
        max_lr: Peak learning rate, reached when warmup ends.
        min_lr: Final learning rate; also returned for every
            ``step >= total_steps``.

    Returns:
        The learning rate at ``step``.
    """
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    if step >= total_steps:
        # Past the end of the schedule the line would keep falling below
        # min_lr (and eventually below zero); hold it at the floor instead.
        # This guard also covers total_steps == warmup_steps, which would
        # otherwise divide by zero below.
        return min_lr
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return min_lr + (max_lr - min_lr) * (1 - progress)

# Evaluate the linear schedule at every step of a run and plot the curve.
steps = 47683
lrs = [linear_lr_schedule(s, total_steps=steps) for s in range(steps)]
plt.plot(
    lrs,
    color='b',
    linestyle='-',
)

In [21]:
def get_adamw_decay_factors(step_start, step_end, total_steps, weight_decay, warmup_steps=700, max_lr=6e-4, min_lr=6e-5, get_lr = cosine_lr_schedule):
    """Return the cumulative weight-decay multipliers over ``[step_start, step_end)``.

    Every step multiplies the weights by ``1 - lr * weight_decay``, where
    ``lr`` is obtained from ``get_lr(step, total_steps, warmup_steps, max_lr,
    min_lr)``. Entry ``k`` of the result is the product of these multipliers
    for steps ``step_start`` through ``step_start + k``, so the last entry is
    the total shrinkage applied between ``step_start`` and ``step_end``.

    Returns:
        A 1-D numpy array with ``step_end - step_start`` entries (empty when
        the range is empty).
    """
    shrink_per_step = [
        1 - get_lr(s, total_steps, warmup_steps, max_lr, min_lr) * weight_decay
        for s in range(step_start, step_end)
    ]
    return np.cumprod(shrink_per_step)

def model_composition_plot(model_norm):
    """Draw ``model_norm`` as one long horizontal bar made of stacked segments.

    A new 10x1 figure is created (and becomes pyplot's current figure, so a
    following ``plt.savefig`` writes it). The segments are laid out left to
    right in the order given, each outlined in black, and the axes are hidden.

    Args:
        model_norm: Sequence of segment widths. May be empty, in which case an
            empty figure is produced. Colors are taken from the default
            matplotlib color cycle and repeat when there are more segments
            than colors in the cycle.
    """
    # Use the default matplotlib color cycle
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

    # width=10, height=1 gives a long, thin bar
    fig, ax = plt.subplots(figsize=(10, 1))

    # Each segment starts where the previous one ended; keep a running offset
    # rather than re-summing all earlier widths for every segment.
    left = 0.0
    for i, width in enumerate(model_norm):
        ax.barh(0, width, left=left, color=colors[i % len(colors)], edgecolor='black')
        left += width

    # Remove axes
    ax.axis('off')

### 124M 1x

In [None]:
# Schedule settings for the 124M 1x run.
max_lr, min_lr = 6.0E-4, 6.0E-5
warmup_steps = 700
weight_decay = 0.1
total_steps = 4730

# One line per starting point at 0%, 10%, ..., 90% of training: the cumulative
# weight-decay factor from that starting step through the end of training.
for i in range(0, 100, 10):
    step_start = int(i / 100 * total_steps)
    x = np.arange(step_start, total_steps)
    decay_factor = get_adamw_decay_factors(
        step_start=step_start,
        step_end=total_steps,
        total_steps=total_steps,
        weight_decay=weight_decay,
        warmup_steps=warmup_steps,
        max_lr=max_lr,
        min_lr=min_lr,
        get_lr=cosine_lr_schedule,
    )
    plt.plot(x, decay_factor, linestyle='-')

# Axis ranges and labels.
plt.xlim(0, total_steps)
plt.ylim(0, 1)
plt.xlabel("Gradient Step")
plt.ylabel("Cumulative Weight Decay Factor")
# plt.legend(["0%", "10%", "20%", "30%", "40%", "50%", "60%", "70%", "80%", "90%"])

# Write the figure to disk, then display it.
plt.savefig("figures/decays/124_1x_Chinchilla.pdf")
plt.show()

In [None]:
# Estimate how much each 10% slice of training contributes to the final
# weights: (sum of learning rates in the slice) x (weight decay applied from
# the slice's midpoint to the end of training).
model_norm = []
for i in range(0, 100, 10):
    # boundaries of the current 10% slice
    step_start = int(i / 100 * total_steps)
    step_end = int((i+10) / 100 * total_steps)

    # sum of all learning rates inside the slice
    lr_sum = 0
    for step in range(step_start, step_end):
        lr_sum = lr_sum + cosine_lr_schedule(
            step=step,
            total_steps=total_steps,
            warmup_steps=warmup_steps,
            max_lr=max_lr,
            min_lr=min_lr,
        )

    # approximation: decay the whole slice as if it happened at its middle step
    decay_factor = get_adamw_decay_factors(
        step_start=int((step_start+step_end)/2),
        step_end=total_steps,
        total_steps=total_steps,
        weight_decay=weight_decay,
        warmup_steps=warmup_steps,
        max_lr=max_lr,
        min_lr=min_lr,
        get_lr=cosine_lr_schedule,
    )
    print(f"Step {step_start} - Step {step_end},  {lr_sum:.3f}, {decay_factor[-1]:.3f}")
    model_norm.append(decay_factor[-1] * lr_sum)

# normalize so the slices sum to one
model_norm = np.array(model_norm)
model_norm = model_norm / np.sum(model_norm)
print(model_norm)

# one bar segment per slice, sized by its share
model_composition_plot(model_norm)
plt.savefig("figures/decays/124_1x_Chinchilla_composition.pdf")


### 124M 15x

In [None]:
# Schedule settings for the 124M 15x run.
max_lr, min_lr = 6.0E-4, 6.0E-5
warmup_steps = 700
weight_decay = 0.1
total_steps = 47000 // 10 * 15

# One line per starting point at 0%, 10%, ..., 90% of training: the cumulative
# weight-decay factor from that starting step through the end of training.
for i in range(0, 100, 10):
    step_start = int(i / 100 * total_steps)
    x = np.arange(step_start, total_steps)
    decay_factor = get_adamw_decay_factors(
        step_start=step_start,
        step_end=total_steps,
        total_steps=total_steps,
        weight_decay=weight_decay,
        warmup_steps=warmup_steps,
        max_lr=max_lr,
        min_lr=min_lr,
        get_lr=cosine_lr_schedule,
    )
    plt.plot(x, decay_factor, linestyle='-')

# Axis ranges and labels.
plt.xlim(0, total_steps)
plt.ylim(0, 1)
plt.xlabel("Gradient Step")
plt.ylabel("Cumulative Weight Decay")
# plt.title("124M 15x Chinchilla Cumulative Weight Decay")
# plt.legend(["0%", "10%", "20%", "30%", "40%", "50%", "60%", "70%", "80%", "90%"])

# Write the figure to disk, then display it.
plt.savefig("figures/124M_15x_wd.pdf")
plt.show()

In [None]:
# Figure showing how the final model weights are composed of the gradient
# updates seen during each 10% slice of training.
max_lr, min_lr = 6.0E-4, 6.0E-5
warmup_steps = 700
weight_decay = 0.1
total_steps = 47000 // 10 * 15

model_norm = []
for i in range(0, 100, 10):
    # boundaries of the current 10% slice
    step_start = int(i / 100 * total_steps)
    step_end = int((i+10) / 100 * total_steps)

    # sum of all learning rates inside the slice
    lr_sum = 0
    for step in range(step_start, step_end):
        lr_sum = lr_sum + cosine_lr_schedule(
            step=step,
            total_steps=total_steps,
            warmup_steps=warmup_steps,
            max_lr=max_lr,
            min_lr=min_lr,
        )

    # approximation: decay the whole slice as if it happened at its middle step
    decay_factor = get_adamw_decay_factors(
        step_start=int((step_start+step_end)/2),
        step_end=total_steps,
        total_steps=total_steps,
        weight_decay=weight_decay,
        warmup_steps=warmup_steps,
        max_lr=max_lr,
        min_lr=min_lr,
        get_lr=cosine_lr_schedule,
    )
    print(f"Step {step_start} - Step {step_end},  {lr_sum:.3f}, {decay_factor[-1]:.3f}")
    model_norm.append(decay_factor[-1] * lr_sum)

# normalize so the slices sum to one
model_norm = np.array(model_norm)
model_norm = model_norm / np.sum(model_norm)
print(model_norm)

# one bar segment per slice, sized by its share
model_composition_plot(model_norm)

plt.savefig("figures/124M_15x_model_composition.pdf", bbox_inches='tight')

### 1.6B 1x

In [None]:
# Schedule settings for the 1.6B 1x run.
max_lr, min_lr = 2.0E-4, 2.0E-5
warmup_steps = 700
weight_decay = 0.1
total_steps = 30803

# One line per starting point at 0%, 10%, ..., 90% of training: the cumulative
# weight-decay factor from that starting step through the end of training.
for i in range(0, 100, 10):
    step_start = int(i / 100 * total_steps)
    x = np.arange(step_start, total_steps)
    decay_factor = get_adamw_decay_factors(
        step_start=step_start,
        step_end=total_steps,
        total_steps=total_steps,
        weight_decay=weight_decay,
        warmup_steps=warmup_steps,
        max_lr=max_lr,
        min_lr=min_lr,
        get_lr=cosine_lr_schedule,
    )
    plt.plot(x, decay_factor, linestyle='-')

# Axis ranges and labels.
plt.xlim(0, total_steps)
plt.ylim(0, 1)
plt.xlabel("Gradient Step")
plt.ylabel("Cumulative Weight Decay Factor")
# plt.title("1.6B 1x Chinchilla Cumulative Weight Decay")
# plt.legend(["0%", "10%", "20%", "30%", "40%", "50%", "60%", "70%", "80%", "90%"])

# Write the figure to disk, then display it.
plt.savefig("figures/decays/1.6B_1x_Chinchilla.pdf")
plt.show()

# Estimate how much each 10% slice of training contributes to the final
# weights: (sum of learning rates in the slice) x (weight decay applied from
# the slice's midpoint to the end of training).
model_norm = []
for i in range(0, 100, 10):
    # boundaries of the current 10% slice
    step_start = int(i / 100 * total_steps)
    step_end = int((i+10) / 100 * total_steps)

    # sum of all learning rates inside the slice
    lr_sum = 0
    for step in range(step_start, step_end):
        lr_sum = lr_sum + cosine_lr_schedule(
            step=step,
            total_steps=total_steps,
            warmup_steps=warmup_steps,
            max_lr=max_lr,
            min_lr=min_lr,
        )

    # approximation: decay the whole slice as if it happened at its middle step
    decay_factor = get_adamw_decay_factors(
        step_start=int((step_start+step_end)/2),
        step_end=total_steps,
        total_steps=total_steps,
        weight_decay=weight_decay,
        warmup_steps=warmup_steps,
        max_lr=max_lr,
        min_lr=min_lr,
        get_lr=cosine_lr_schedule,
    )
    print(f"Step {step_start} - Step {step_end},  {lr_sum:.3f}, {decay_factor[-1]:.3f}")
    model_norm.append(decay_factor[-1] * lr_sum)

# normalize so the slices sum to one
model_norm = np.array(model_norm)
model_norm = model_norm / np.sum(model_norm)
print(model_norm)

# one bar segment per slice, sized by its share
model_composition_plot(model_norm)
plt.savefig("figures/decays/1.6B_1x_Chinchilla_composition.pdf")

## OLMo 7B

In [None]:
# Schedule settings for the OLMo 7B run (linear learning rate decay).
max_lr, min_lr = 3.0E-4, 3.0E-5
warmup_steps = 5000
weight_decay = 0.1
total_steps = 615000

# One line per starting point at 0%, 10%, ..., 90% of training: the cumulative
# weight-decay factor from that starting step through the end of training.
for i in range(0, 100, 10):
    step_start = int(i / 100 * total_steps)
    x = np.arange(step_start, total_steps)
    decay_factor = get_adamw_decay_factors(
        step_start=step_start,
        step_end=total_steps,
        total_steps=total_steps,
        weight_decay=weight_decay,
        warmup_steps=warmup_steps,
        max_lr=max_lr,
        min_lr=min_lr,
        get_lr=linear_lr_schedule,
    )
    plt.plot(x, decay_factor, linestyle='-')

# Axis ranges and labels.
plt.xlim(0, total_steps)
plt.ylim(0, 1)
plt.xlabel("Gradient Step")
plt.ylabel("Cumulative Weight Decay")
# plt.title("OLMo 7B Cumulative Weight Decay")
# plt.legend(["0%", "10%", "20%", "30%", "40%", "50%", "60%", "70%", "80%", "90%"])

# Place an x tick every 200000 steps, labelled with the step number.
x = np.arange(0, total_steps, 200000)
plt.xticks(x, x)

# Write the figure to disk, then display it.
plt.savefig("figures/OLMo_7B_wd.pdf")
plt.show()

## Llama 405B

#### "We pre-train Llama 3 405B using AdamW with a peak learning rate of $8 \times 10^{-5}$, a linear warm up of 8,000 steps, and a cosine learning rate schedule decaying to $8 \times 10^{-7}$ over 1,200,000 steps"

In [None]:
# Schedule settings for Llama 3 405B. The commented-out title below notes
# that weight_decay = 0.1 is an assumption.
max_lr, min_lr = 8e-5, 8e-7
warmup_steps = 8000
weight_decay = 0.1
total_steps = 1200000

# One line per starting point at 0%, 10%, ..., 90% of training: the cumulative
# weight-decay factor from that starting step through the end of training.
for i in range(0, 100, 10):
    step_start = int(i / 100 * total_steps)
    x = np.arange(step_start, total_steps)
    decay_factor = get_adamw_decay_factors(
        step_start=step_start,
        step_end=total_steps,
        total_steps=total_steps,
        weight_decay=weight_decay,
        warmup_steps=warmup_steps,
        max_lr=max_lr,
        min_lr=min_lr,
        get_lr=cosine_lr_schedule,
    )
    plt.plot(x, decay_factor, linestyle='-')

# Axis ranges and labels.
plt.xlim(0, total_steps)
plt.ylim(0, 1)
plt.xlabel("Gradient Step")
plt.ylabel("Cumulative Weight Decay")
# plt.title("Llama 3 405B Cumulative Weight Decay (assumes wd=0.1)")
# plt.legend(["0%", "10%", "20%", "30%", "40%", "50%", "60%", "70%", "80%", "90%"])

# Place an x tick every 200000 steps, labelled with the step number.
x = np.arange(0, total_steps, 200000)
plt.xticks(x, x)

# Write the figure to disk, then display it.
plt.savefig("figures/LLama3_405B_wd.pdf")
plt.show()