In [None]:
import subprocess
from pathlib import Path

# Fetch the megasam repository into the working directory (the next cell adds
# "megasam" to sys.path). The clone is skipped when the checkout already exists,
# so "Restart & Run All" does not error out. check=True makes a failed clone
# (e.g. no network) raise here instead of surfacing later as an ImportError.
MEGASAM_REPO_URL = "https://github.com/zsiga007/megasam.git"
if not Path("megasam").is_dir():
    subprocess.run(["git", "clone", MEGASAM_REPO_URL], check=True)

In [1]:
import sys

# Make modules from the parent directory and from the cloned `megasam` checkout importable.
sys.path.extend(["..", "megasam"])

In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms, utils
from tqdm.notebook import tqdm, trange
from tqdm.notebook import tqdm
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
import argparse
import torch

from model.wide_res_net import WideResNet
from model.smooth_cross_entropy import smooth_crossentropy
from data.cifar import Cifar
from utility.log import Log
from utility.initialize import initialize
from utility.step_lr import StepLR
from utility.bypass_bn import enable_running_stats, disable_running_stats

# Train/test loaders shared by both experiments below
# (`dataset.train` / `dataset.test` are iterated in the training cells).
BATCH_SIZE = 128
NUM_LOADER_THREADS = 4
dataset = Cifar(batch_size=BATCH_SIZE, threads=NUM_LOADER_THREADS)

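## VSAM

Train a WideResNet-28-10 (depth 28, width factor 10) on CIFAR with the `VSAM` optimizer from `meanfield_optimizer`, using SGD as the base optimizer.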

In [None]:
from meanfield_optimizer import VSAM


# Train WideResNet-28-10 with VSAM (SGD base optimizer). Per-batch training
# stats and per-epoch test stats are recorded through `log`.
num_epochs = 200
SEED = 42  # fixed seed: reproducible model init (and, presumably, loader shuffling)

torch.manual_seed(SEED)
log = Log(log_each=10)
model = WideResNet(depth=28, width_factor=10, dropout=0.0, in_channels=3, labels=10).to(device)
base_optimizer = torch.optim.SGD
# NOTE(review): rho=1000 and lr_M=0.01 are VSAM-specific hyperparameters whose meaning
# is defined in meanfield_optimizer — confirm these values before reporting results.
optimizer = VSAM(model.parameters(), base_optimizer, momentum=0.9, rho=1000, lr=0.1, weight_decay=0.0005, lr_M=0.01)
scheduler = StepLR(optimizer, 0.1, num_epochs)

for epoch in range(num_epochs):
    model.train()
    log.train(len_dataset=len(dataset.train))

    for batch in dataset.train:
        inputs, targets = (b.to(device) for b in batch)
        optimizer.zero_grad()

        # First forward-backward pass, with BatchNorm running-stat updates enabled.
        enable_running_stats(model)
        predictions = model(inputs)
        loss = smooth_crossentropy(predictions, targets, smoothing=0.1)
        loss.mean().backward()

        def closure():
            # Re-evaluate the loss on the same batch and backprop it. The optimizer
            # receives this in `step`, presumably evaluating it at perturbed weights.
            closure_predictions = model(inputs)
            closure_loss = smooth_crossentropy(closure_predictions, targets, smoothing=0.1)
            closure_loss.mean().backward()
            return closure_loss

        # Disable BatchNorm running-stat updates for the closure's extra forward
        # pass, so each batch updates the running statistics only once.
        disable_running_stats(model)
        optimizer.step(closure)

        with torch.no_grad():
            correct = torch.argmax(predictions, 1) == targets
            log(model, loss.cpu(), correct.cpu(), scheduler.lr())
            scheduler(epoch)

    model.eval()
    log.eval(len_dataset=len(dataset.test))

    with torch.no_grad():
        for batch in dataset.test:
            inputs, targets = (b.to(device) for b in batch)

            predictions = model(inputs)
            loss = smooth_crossentropy(predictions, targets)
            correct = torch.argmax(predictions, 1) == targets
            log(model, loss.cpu(), correct.cpu())

log.flush()

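## RSAM (`RandomSAM`)

Same WideResNet-28-10 / CIFAR setup, trained with the `RandomSAM` optimizer from `meanfield_optimizer` (SGD base optimizer).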

In [None]:
from meanfield_optimizer import RandomSAM


# Train WideResNet-28-10 with RandomSAM (SGD base optimizer). Per-batch training
# stats and per-epoch test stats are recorded through `log`.
# NOTE(review): this runs 400 epochs while the VSAM experiment runs 200 —
# confirm this is intentional before comparing the two results.
num_epochs = 400
SEED = 42  # fixed seed: reproducible model init (and, presumably, loader shuffling)

torch.manual_seed(SEED)
log = Log(log_each=10)
model = WideResNet(depth=28, width_factor=10, dropout=0.0, in_channels=3, labels=10).to(device)

base_optimizer = torch.optim.SGD
optimizer = RandomSAM(model.parameters(), base_optimizer, lr=0.1,  momentum=0.9, weight_decay=0.0005)
scheduler = StepLR(optimizer, 0.1, num_epochs)

for epoch in range(num_epochs):
    model.train()
    log.train(len_dataset=len(dataset.train))

    for batch in dataset.train:
        inputs, targets = (b.to(device) for b in batch)
        optimizer.zero_grad()

        # First forward-backward pass, with BatchNorm running-stat updates enabled.
        enable_running_stats(model)
        predictions = model(inputs)
        loss = smooth_crossentropy(predictions, targets, smoothing=0.1)
        loss.mean().backward()

        def closure():
            # Re-evaluate the loss on the same batch and backprop it; handed to
            # the optimizer's `step`.
            closure_predictions = model(inputs)
            closure_loss = smooth_crossentropy(closure_predictions, targets, smoothing=0.1)
            closure_loss.mean().backward()
            return closure_loss

        # Disable BatchNorm running-stat updates for the closure's extra forward
        # pass, so each batch updates the running statistics only once.
        disable_running_stats(model)
        optimizer.step(closure)

        with torch.no_grad():
            correct = torch.argmax(predictions, 1) == targets
            log(model, loss.cpu(), correct.cpu(), scheduler.lr())
            scheduler(epoch)

    model.eval()
    log.eval(len_dataset=len(dataset.test))

    with torch.no_grad():
        for batch in dataset.test:
            inputs, targets = (b.to(device) for b in batch)

            predictions = model(inputs)
            loss = smooth_crossentropy(predictions, targets)
            correct = torch.argmax(predictions, 1) == targets
            log(model, loss.cpu(), correct.cpu())

log.flush()