In [None]:
from pathlib import Path

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim

from models.resnet import ResNet50
from utils.datasets import get_datasets
from utils.reproducibility import make_it_reproducible, seed_worker

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
# Reproducibility setup.
# Dedicated generator object, created here without an explicit seed; the
# training cell calls g.manual_seed(seed) and hands it to the DataLoaders.
g = torch.Generator()

# Seed used for this experiment. Other values noted by the author: 128, 479
# (presumably alternative seeds for repeated runs -- confirm before relying on it).
seed = 0

In [None]:
# setting parameters
# Number of training epochs run for each model; it is the bound of the
# `for epoch in range(EPOCHS)` loop in the training cell.
EPOCHS = 50

In [None]:
# Unpack the two datasets returned by the project helper utils.datasets.get_datasets.
# Which dataset and which transforms are used is decided inside that helper
# (not visible in this notebook).
trainset, testset = get_datasets()

In [None]:
# Centralized baseline: train one ResNet50 per normalization-layer variant with an
# otherwise identical optimizer/scheduler setup, evaluate on the test set after
# every epoch, and store the per-epoch metrics in a CSV file.
criterion = nn.CrossEntropyLoss()

# Create the output directory *before* training starts: DataFrame.to_csv does not
# create missing parent directories, and failing only after the last epoch would
# discard every collected result.
results_path = Path("./results/centralized_baseline/centralized_results.csv")
results_path.parent.mkdir(parents=True, exist_ok=True)

metrics = []
# Each entry is [norm name passed to ResNet50, SGD momentum, SGD weight decay].
for norm, momentum, wd in [["Batch Norm", 0.5, 0], ["Group Norm", 0.5, 0]]:
    # Re-seed before every run so both variants start from the same seed.
    # make_it_reproducible is a project helper (presumably seeds the global RNGs
    # -- see utils.reproducibility); g drives the DataLoaders below.
    make_it_reproducible(seed)
    g.manual_seed(seed)

    # Both loaders share the generator g and the project's worker_init_fn.
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=128, shuffle=True, num_workers=2,
        worker_init_fn=seed_worker, generator=g,
    )
    testloader = torch.utils.data.DataLoader(
        testset,
        batch_size=100, shuffle=False, num_workers=2,
        worker_init_fn=seed_worker, generator=g,
    )

    net = ResNet50(norm)
    net.to(device)
    optimizer = optim.SGD(net.parameters(), lr=1e-2, momentum=momentum, weight_decay=wd)
    # Learning rate is multiplied by 0.33 once epochs 20, 30 and 40 have been completed.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[20, 30, 40], gamma=0.33)

    for epoch in range(EPOCHS):
        # ---- training pass ----
        net.train()
        train_loss = []
        for inputs, targets in trainloader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss.append(loss.item())
        # NOTE: this is the unweighted mean of the per-batch losses. If the dataset
        # size is not a multiple of the batch size, the smaller final batch counts
        # as much as a full one. Kept as is so the metric definition stays unchanged.
        train_loss_avg = sum(train_loss) / len(train_loss)

        # ---- evaluation pass (no gradients, eval-mode layers) ----
        net.eval()
        test_loss = []
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in testloader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = net(inputs)
                loss = criterion(outputs, targets)
                # Predicted class = index of the largest output along dim 1.
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

                test_loss.append(loss.item())

            test_loss_avg = sum(test_loss) / len(test_loss)
            test_accuracy = correct / total

        metrics.append({
            "norm": "BN" if norm == "Batch Norm" else "GN",
            "seed": seed,
            "epoch": epoch + 1,  # stored 1-based
            "train_loss": train_loss_avg,
            "test_loss": test_loss_avg,
            "test_accuracy": test_accuracy,
        })
        scheduler.step()
        # Progress report every fifth epoch (1, 6, 11, ...) and additionally for the
        # final epoch, which the `epoch % 5` condition alone would never show.
        if epoch % 5 == 0 or epoch == EPOCHS - 1:
            print(epoch+1, test_accuracy)

    # Write everything collected so far after each finished run, so that a failure
    # during a later run cannot discard the results of the earlier ones. After the
    # last run, df and the CSV file contain the metrics of all runs.
    df = pd.DataFrame(metrics)
    df.to_csv(results_path, index=False)