In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from src.badam import BlockOptimizer

  from .autonotebook import tqdm as notebook_tqdm


In [113]:
class Net(nn.Module):
    
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(10, 1)
        
    def forward(self, x):
        x = self.fc1(x)
        #print(f'Output after fc1: {x}')
        x = self.relu(x)
        #print(f'Output after ReLU: {x}')
        x = self.fc2(x)
        #print(f'Output after fc2: {x}')
        return x

model = Net().cuda()  
criterion = nn.MSELoss()
optimizer = adam_optimizer

model.train()
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    
    optimizer.zero_grad()  
    
    outputs = model(x_train)
    loss = criterion(outputs, y_train)
    
    loss.backward()  
    
    for name, param in model.named_parameters():
        if param.grad is not None:
            print(f'Gradient for {name}')
    
    optimizer.step()  
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

Gradient for fc1.weight
Gradient for fc1.bias
Gradient for fc2.weight
Gradient for fc2.bias
Epoch [1/10], Loss: 1.2658
Gradient for fc1.weight
Gradient for fc1.bias
Gradient for fc2.weight
Gradient for fc2.bias
Epoch [2/10], Loss: 1.2658
Gradient for fc1.weight
Gradient for fc1.bias
Gradient for fc2.weight
Gradient for fc2.bias
Epoch [3/10], Loss: 1.2658
Gradient for fc1.weight
Gradient for fc1.bias
Gradient for fc2.weight
Gradient for fc2.bias
Epoch [4/10], Loss: 1.2658
Gradient for fc1.weight
Gradient for fc1.bias
Gradient for fc2.weight
Gradient for fc2.bias
Epoch [5/10], Loss: 1.2658
Gradient for fc1.weight
Gradient for fc1.bias
Gradient for fc2.weight
Gradient for fc2.bias
Epoch [6/10], Loss: 1.2658
Gradient for fc1.weight
Gradient for fc1.bias
Gradient for fc2.weight
Gradient for fc2.bias
Epoch [7/10], Loss: 1.2658
Gradient for fc1.weight
Gradient for fc1.bias
Gradient for fc2.weight
Gradient for fc2.bias
Epoch [8/10], Loss: 1.2658
Gradient for fc1.weight
Gradient for fc1.bias
Gr

In [108]:
# 0. setup
device = 'cuda:0'

x_train = torch.randn(64, 10)
y_train = torch.randn(64, 1)

x_train = x_train.to(device)
y_train = y_train.to(device)

In [105]:
# 1. Initialize a simple fc model
class Model(nn.Module):
    
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(10, 1)
        
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
    
    
model = Model().cuda()

In [137]:
# 2. Initialize the BAdam Optimizer
adam_optimizer = optim.Adam(model.parameters(), lr=0.001)
badam_optimizer = BlockOptimizer(
    base_optimizer = adam_optimizer,
    named_parameters_list = list(model.named_parameters()),
    switch_block_every = 10,
    switch_mode='ascending',
    verbose=2,
    block_prefix_list = block_prefix_list
)

badam_optimizer

Parameters with the following prefix will be trainable: ['fc1.weight']
fc1.weight


BlockOptimizer (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    weight_decay: 0
)

In [43]:
# 3. Define the loss function
criterion = nn.MSELoss()

In [127]:
# 4. CUDA Automatic Mixed Prercision https://pytorch.org/docs/stable/notes/amp_examples.html#amp-examples
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

In [138]:
model.train()
num_epochs = 20
for epoch in range(num_epochs):
    outputs = model(x_train)
    loss = criterion(outputs, y_train)
    
    badam_optimizer.zero_grad() 
    loss.backward()
    
    for name, param in model.named_parameters():
        if param.grad is not None:
            print(f'name {name}')
    
    badam_optimizer.step()
    
    
    print(f'loss: {loss.item():.4f} type {loss.dtype}')

name fc1.weight
loss: 1.1344 type torch.float32
name fc1.weight
loss: 1.1321 type torch.float32
name fc1.weight
loss: 1.1298 type torch.float32
name fc1.weight
loss: 1.1277 type torch.float32
name fc1.weight
loss: 1.1255 type torch.float32
name fc1.weight
loss: 1.1234 type torch.float32
name fc1.weight
loss: 1.1213 type torch.float32
name fc1.weight
loss: 1.1192 type torch.float32
name fc1.weight
Parameters with the following prefix will be trainable: ['fc1.bias']
fc1.bias
loss: 1.1171 type torch.float32
name fc1.bias
loss: 1.1151 type torch.float32
name fc1.bias
loss: 1.1146 type torch.float32
name fc1.bias
loss: 1.1142 type torch.float32
name fc1.bias
loss: 1.1138 type torch.float32
name fc1.bias
loss: 1.1134 type torch.float32
name fc1.bias
loss: 1.1130 type torch.float32
name fc1.bias
loss: 1.1126 type torch.float32
name fc1.bias
loss: 1.1122 type torch.float32
name fc1.bias
loss: 1.1119 type torch.float32
name fc1.bias
Parameters with the following prefix will be trainable: ['fc2.

In [130]:
# 5. use half precision/mis precision
scaler = GradScaler()

num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    with autocast():
        outputs = model(x_train)
        #print(f'output precision {outputs.dtype}')
        loss = criterion(outputs, y_train)
        #print(f'loss precision {loss.dtype} loss before scale {loss.item():.4f}')
    
    badam_optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(badam_optimizer)
    scaler.update()
    