In [1]:
import torch
import torch.nn as nn
from tqdm import tqdm

from dataloaders.MNISTDataset import get_dataloaders
from nn.Net import Net
from nn.ControlNet import ControlNet
from utils.colored_prints import *

In [2]:
train_loader, test_loader = get_dataloaders(0)


In [3]:
num_epochs = 50
inner_epochs = 156
learning_rate = 0.001
control_lr = 0.001
control_threshold = 1e-3
l1_lambda = 0.0


input_size_net = 28 * 28
hidden_size_net = 100
output_size_net = 10
hidden_size_control = 100

input_size_control = input_size_net + hidden_size_net + output_size_net

In [4]:
net = Net(
    input_size=input_size_net,
    hidden_size=hidden_size_net,
    output_size=output_size_net,
    softmax=False,
)

In [5]:
control_net = ControlNet(
    input_size=input_size_control,
    hidden_size=hidden_size_control,
    output_size=hidden_size_net + output_size_net,
)

In [6]:
criterion = nn.CrossEntropyLoss()
control_optimizer = torch.optim.Adam(control_net.parameters(), lr=float(control_lr))
net_optimizer = torch.optim.Adam(net.parameters(), lr=float(learning_rate))

In [8]:
def evaluate_model(net, control_net, test_loader):
    correct = 0
    total = 0

    with torch.no_grad():
        for batch_data, batch_labels in test_loader:
            net.reset_control_signals()
            h1 = net.layer1(net.flatten(batch_data))
            output = net(batch_data)
            current_activities = torch.cat([net.flatten(batch_data), h1, output], dim=1)

            control_signals = control_net(current_activities)
            net.set_control_signals(control_signals)
            output = net(batch_data)

            predictions = output.max(dim=1).indices
            total += batch_labels.size(0)
            correct += (predictions == batch_labels).sum().item()

    accuracy = 100 * correct / total

    return accuracy

In [9]:
def train_model(
    net,
    control_net,
    train_loader,
    criterion,
    control_optimizer,
    net_optimizer,
    control_threshold,
    l1_lambda,
    verbose = False
):
    pbar = tqdm(range(num_epochs), desc=f"Epochs", leave=False)

    for epoch in pbar:
        batch_losses = []

        for batch_data, batch_labels in train_loader:
            inner_epoch_correct = None
            inner_epoch_cvg = None

            # Get current network activities
            with torch.no_grad():
                net.reset_control_signals()
                h1 = net.layer1(net.flatten(batch_data))
                output = net(batch_data)
                current_activities = torch.cat(
                    [net.flatten(batch_data), h1, output], dim=1
                )

            old_loss = float("inf")
            for inner_epoch in range(100):
                control_optimizer.zero_grad()
                net_optimizer.zero_grad()  # TODO: Do I need this?

                control_signals = control_net(current_activities)
                net.set_control_signals(control_signals)

                output = net(batch_data)  # net is excluded from the graph

                # hardcoded label
                control_loss = criterion(output, batch_labels)

                # l1_reg = l1_lambda * (net(data) - label).abs().sum(dim=1).mean()

                control_loss.backward()

                control_optimizer.step()
                net_optimizer.step()

                if (
                    torch.argmax(output, dim=1) == batch_labels
                ).all() and inner_epoch_correct is None:
                    inner_epoch_correct = inner_epoch

                if abs(old_loss - control_loss.item()) < control_threshold:
                    inner_epoch_cvg = inner_epoch
                    break

                old_loss = control_loss.item()

            acc = (
                torch.sum(torch.argmax(output, dim=1) == batch_labels).item()
                / batch_labels.size(0)
                * 100
            )

            if verbose:
                if acc < 80:
                    print_error(f"Fail! {acc:.2f}% at inner_epoch {inner_epoch}")
                else:
                    print_info(f"Win {acc:.2f}% at inner_epoch {inner_epoch}")

                if inner_epoch_cvg is None:
                    print(f"Failed to converge")
                else:
                    print(f"Converged at inner epoch {inner_epoch_cvg}")

                print("\n")

            if control_loss.item() > 0.01:
                batch_losses.append(control_loss.item())
                with torch.no_grad():
                    control_signals = control_net(current_activities)
                    # a.shape is [batch_size, hidden_size + output_size]
                    a1 = control_signals[:, : net.hidden_size]
                    a2 = control_signals[:, net.hidden_size :]

                    # Sander said, we can use 1.0 as the baseline
                    baseline_a1 = torch.ones_like(a1)
                    baseline_a2 = torch.ones_like(a2)
                    a1_diff = a1 - baseline_a1
                    a2_diff = a2 - baseline_a2

                    # Layer 1 weight update
                    x = net.flatten(batch_data)
                    phi = net.hidden_activations(net.layer1(x))
                    r_post_adjusted = phi * a1 * a1_diff
                    dw = r_post_adjusted.T @ x
                    dw = dw / x.shape[0]
                    net.layer1.weight.grad = torch.clamp(dw, min=-1, max=1)

                    # Layer 2 weight update
                    x2 = net.hidden_activations(net.layer1(net.flatten(batch_data)))
                    phi2 = net.output_activations(net.layer2(x2))
                    r_post_adjusted2 = phi2 * a2 * a2_diff
                    dw2 = r_post_adjusted2.T @ x2
                    dw2 = dw2 / x2.shape[0]
                    net.layer2.weight.grad = torch.clamp(dw2, min=-1, max=1)

                    net_optimizer.step()
        epoch_loss = sum(batch_losses) / len(batch_losses) if batch_losses else 0
        accuracy = evaluate_model(net, control_net, test_loader)
        print(f"Epoch {epoch}  Loss: {epoch_loss}  Accuracy: {accuracy:.2f}%")

In [10]:
train_model(
    net,
    control_net,
    train_loader,
    criterion,
    control_optimizer,
    net_optimizer,
    control_threshold,
    l1_lambda,
)

Epochs:   0%|          | 0/50 [00:00<?, ?it/s]

[92m[INFO][0m Win 100.00% at inner_epoch 28
Converged at inner epoch 28


[92m[INFO][0m Win 100.00% at inner_epoch 26
Converged at inner epoch 26


[92m[INFO][0m Win 100.00% at inner_epoch 20
Converged at inner epoch 20


[92m[INFO][0m Win 100.00% at inner_epoch 53
Converged at inner epoch 53


[92m[INFO][0m Win 100.00% at inner_epoch 58
Converged at inner epoch 58


[92m[INFO][0m Win 100.00% at inner_epoch 51
Converged at inner epoch 51


[92m[INFO][0m Win 100.00% at inner_epoch 65
Converged at inner epoch 65


[92m[INFO][0m Win 100.00% at inner_epoch 51
Converged at inner epoch 51


[92m[INFO][0m Win 100.00% at inner_epoch 60
Converged at inner epoch 60


[92m[INFO][0m Win 100.00% at inner_epoch 55
Converged at inner epoch 55


[92m[INFO][0m Win 90.62% at inner_epoch 7
Converged at inner epoch 7


[92m[INFO][0m Win 100.00% at inner_epoch 68
Converged at inner epoch 68


[92m[INFO][0m Win 100.00% at inner_epoch 42
Converged at inner epoch 42


[92m[INFO][0m

Epochs:   2%|▏         | 1/50 [04:12<3:26:05, 252.36s/it]

Epoch 0  Loss: 0.13572101545785287  Accuracy: 81.65%
[92m[INFO][0m Win 100.00% at inner_epoch 25
Converged at inner epoch 25


[92m[INFO][0m Win 100.00% at inner_epoch 9
Converged at inner epoch 9


[92m[INFO][0m Win 100.00% at inner_epoch 18
Converged at inner epoch 18


[92m[INFO][0m Win 100.00% at inner_epoch 11
Converged at inner epoch 11


[92m[INFO][0m Win 100.00% at inner_epoch 22
Converged at inner epoch 22


[92m[INFO][0m Win 100.00% at inner_epoch 16
Converged at inner epoch 16


[92m[INFO][0m Win 100.00% at inner_epoch 4
Converged at inner epoch 4


[92m[INFO][0m Win 100.00% at inner_epoch 16
Converged at inner epoch 16


[92m[INFO][0m Win 100.00% at inner_epoch 3
Converged at inner epoch 3


[92m[INFO][0m Win 100.00% at inner_epoch 9
Converged at inner epoch 9


[92m[INFO][0m Win 100.00% at inner_epoch 33
Converged at inner epoch 33


[92m[INFO][0m Win 100.00% at inner_epoch 23
Converged at inner epoch 23


[92m[INFO][0m Win 100.00% at inner_epoch 

Epochs:   4%|▍         | 2/50 [09:42<3:58:27, 298.06s/it]

Epoch 1  Loss: 0.8680063183031356  Accuracy: 90.40%
[92m[INFO][0m Win 100.00% at inner_epoch 18
Converged at inner epoch 18


[92m[INFO][0m Win 100.00% at inner_epoch 20
Converged at inner epoch 20


[92m[INFO][0m Win 100.00% at inner_epoch 17
Converged at inner epoch 17


[92m[INFO][0m Win 100.00% at inner_epoch 3
Converged at inner epoch 3


[92m[INFO][0m Win 100.00% at inner_epoch 28
Converged at inner epoch 28


[92m[INFO][0m Win 100.00% at inner_epoch 1
Converged at inner epoch 1


[92m[INFO][0m Win 100.00% at inner_epoch 11
Converged at inner epoch 11


[92m[INFO][0m Win 100.00% at inner_epoch 1
Converged at inner epoch 1


[92m[INFO][0m Win 100.00% at inner_epoch 16
Converged at inner epoch 16


[92m[INFO][0m Win 100.00% at inner_epoch 22
Converged at inner epoch 22


[92m[INFO][0m Win 100.00% at inner_epoch 25
Converged at inner epoch 25


[92m[INFO][0m Win 100.00% at inner_epoch 17
Converged at inner epoch 17


[92m[INFO][0m Win 100.00% at inner_epoch

Epochs:   6%|▌         | 3/50 [12:51<3:14:20, 248.10s/it]

Epoch 2  Loss: 1.1244960280600935  Accuracy: 92.65%
[92m[INFO][0m Win 100.00% at inner_epoch 13
Converged at inner epoch 13


[92m[INFO][0m Win 100.00% at inner_epoch 6
Converged at inner epoch 6


[92m[INFO][0m Win 100.00% at inner_epoch 1
Converged at inner epoch 1


[92m[INFO][0m Win 100.00% at inner_epoch 6
Converged at inner epoch 6


[92m[INFO][0m Win 100.00% at inner_epoch 20
Converged at inner epoch 20


[92m[INFO][0m Win 100.00% at inner_epoch 17
Converged at inner epoch 17


[92m[INFO][0m Win 100.00% at inner_epoch 7
Converged at inner epoch 7


[92m[INFO][0m Win 93.75% at inner_epoch 5
Converged at inner epoch 5


[92m[INFO][0m Win 100.00% at inner_epoch 11
Converged at inner epoch 11


[92m[INFO][0m Win 100.00% at inner_epoch 10
Converged at inner epoch 10


[92m[INFO][0m Win 100.00% at inner_epoch 23
Converged at inner epoch 23


[92m[INFO][0m Win 100.00% at inner_epoch 12
Converged at inner epoch 12


[92m[INFO][0m Win 100.00% at inner_epoch 12
C

Epochs:   8%|▊         | 4/50 [15:41<2:46:50, 217.61s/it]

Epoch 3  Loss: 3.292930863224543  Accuracy: 91.54%
[92m[INFO][0m Win 100.00% at inner_epoch 22
Converged at inner epoch 22


[92m[INFO][0m Win 100.00% at inner_epoch 5
Converged at inner epoch 5


[92m[INFO][0m Win 100.00% at inner_epoch 8
Converged at inner epoch 8


[92m[INFO][0m Win 100.00% at inner_epoch 10
Converged at inner epoch 10


[92m[INFO][0m Win 100.00% at inner_epoch 9
Converged at inner epoch 9


[92m[INFO][0m Win 100.00% at inner_epoch 1
Converged at inner epoch 1


[92m[INFO][0m Win 100.00% at inner_epoch 5
Converged at inner epoch 5


[92m[INFO][0m Win 100.00% at inner_epoch 14
Converged at inner epoch 14


[92m[INFO][0m Win 100.00% at inner_epoch 3
Converged at inner epoch 3


[92m[INFO][0m Win 100.00% at inner_epoch 9
Converged at inner epoch 9


[92m[INFO][0m Win 100.00% at inner_epoch 12
Converged at inner epoch 12


[92m[INFO][0m Win 100.00% at inner_epoch 13
Converged at inner epoch 13


[92m[INFO][0m Win 100.00% at inner_epoch 24
Conve

Epochs:  10%|█         | 5/50 [18:09<2:24:18, 192.40s/it]

Epoch 4  Loss: 2.6431635841727257  Accuracy: 94.39%
[92m[INFO][0m Win 100.00% at inner_epoch 1
Converged at inner epoch 1


[92m[INFO][0m Win 100.00% at inner_epoch 9
Converged at inner epoch 9


[92m[INFO][0m Win 100.00% at inner_epoch 6
Converged at inner epoch 6


[92m[INFO][0m Win 100.00% at inner_epoch 1
Converged at inner epoch 1


[92m[INFO][0m Win 100.00% at inner_epoch 1
Converged at inner epoch 1


[92m[INFO][0m Win 100.00% at inner_epoch 6
Converged at inner epoch 6


[92m[INFO][0m Win 100.00% at inner_epoch 27
Converged at inner epoch 27


[92m[INFO][0m Win 100.00% at inner_epoch 5
Converged at inner epoch 5


[92m[INFO][0m Win 100.00% at inner_epoch 9
Converged at inner epoch 9


[92m[INFO][0m Win 100.00% at inner_epoch 10
Converged at inner epoch 10


[92m[INFO][0m Win 100.00% at inner_epoch 14
Converged at inner epoch 14


[92m[INFO][0m Win 100.00% at inner_epoch 3
Converged at inner epoch 3


[92m[INFO][0m Win 100.00% at inner_epoch 6
Converged

                                                         

KeyboardInterrupt: 