<a href="https://colab.research.google.com/github/zeligism/CubicOneShotSGD/blob/main/CubicOneShotSGD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim  # the optimizer package is `torch.optim`; `torch.optimizer` does not exist
import torch.utils.data as data_utils

In [None]:
# Download datasets
# A9A_DATASET is the local file name the a9a data is saved under.
A9A_DATASET = "a9a.txt"
# IPython shell escape: fetch the a9a file from the LIBSVM dataset collection
# and write it to the path held in A9A_DATASET (`wget -O`).
!wget -O "{A9A_DATASET}" "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/a9a"

## Args

In [None]:
class Args:
    """Experiment configuration shared by the cells of this notebook.

    The settings are plain instance attributes assigned in ``__init__``
    (``self`` is not available directly in a class body).
    """

    def __init__(self):
        # Input / output sizes of the linear model.
        # NOTE(review): feature_dim must equal the number of feature columns
        # produced when the dataset is loaded -- confirm against the a9a file.
        self.feature_dim = 124
        self.output_dim = 2
        # Name of the dataset to load.
        self.dataset = "a9a"
        # Device the models, loss and batches are moved to.
        self.device = "cuda:0"
        # Number of independently trained model replicas.
        self.num_models = 10
        # Number of training iterations.
        self.num_iters = 100
        # Learning rate passed to the optimizer.
        self.base_lr = 1e-3
        # Mini-batch size used by the data loader.
        self.batch_size = 1

args = Args()

## Dataset

In [None]:
from sklearn.datasets import load_svmlight_file

class MyDataset(data_utils.Dataset):
    """In-memory dataset of dense feature rows and their labels.

    Args:
        dataset: Name of the dataset to load. Only ``"a9a"`` is supported.

    Raises:
        ValueError: If ``dataset`` is not a supported name.
    """

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        self.data = None
        self.labels = None

        if self.dataset in ("a9a",):
            # Pass the path itself: load_svmlight_file expects a path or a
            # binary-mode file object, not a text-mode handle.
            X, y = load_svmlight_file(A9A_DATASET)
            # X is a scipy sparse matrix; densify it before building a tensor.
            self.data = torch.Tensor(X.toarray())  # NxX
            self.labels = torch.Tensor(y).unsqueeze(1)  # NxY
        else:
            raise ValueError(f"Dataset '{self.dataset}' not found.")

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at ``idx``."""
        return self.data[idx], self.labels[idx]


## Optimizer

In [None]:
class LocalSGD(torch.optim.SGD):
    """SGD over several model replicas, one param group per replica.

    Every param group is expected to hold the parameters of one replica,
    in the same order, so that the i-th parameter of each group refers to
    the same tensor of the model architecture.
    """

    @torch.no_grad()
    def aggregate(self):
        """Average the replicas and write the average back into each one.

        For every parameter position, the element-wise mean over all param
        groups is computed and copied in place into each replica's parameter.
        The parameter objects themselves are kept, so the optimizer and the
        models keep referring to the same tensors.

        Raises:
            ValueError: If the param groups hold different numbers of
                parameters.
        """
        per_model_params = [group["params"] for group in self.param_groups]
        if not per_model_params:
            return
        num_params = len(per_model_params[0])
        if any(len(params) != num_params for params in per_model_params):
            raise ValueError("All param groups must hold the same number of parameters.")

        # zip(*...) walks the groups in lockstep: one tuple of replicas per parameter.
        for replicas in zip(*per_model_params):
            # dim=0 averages across replicas while keeping the parameter's shape.
            mean = torch.mean(torch.stack([p.detach() for p in replicas]), dim=0)
            # Synchronize: overwrite every replica in place with the average.
            for p in replicas:
                p.copy_(mean)

### Aggregation Schedule

In [None]:
# The schedule is the set of iteration indices at which the models are averaged.
# For one-shot averaging, the only aggregation happens on the last iteration.
aggregation_idxs = {args.num_iters - 1}

## Model

In [None]:
def create_model(model_idx):
    """Build one linear model sized by the global ``args``.

    ``model_idx`` is accepted but not used; every call builds the same
    architecture.
    """
    in_features, out_features = args.feature_dim, args.output_dim
    linear = nn.Linear(in_features, out_features)
    return linear

# Training

Training preparation

In [None]:
# Load the configured dataset and wrap it in a shuffling loader.
dataset = MyDataset(args.dataset)
dataloader = data_utils.DataLoader(
    dataset,
    batch_size=args.batch_size,
    shuffle=True,
    pin_memory=True,
)
# Module-level iterator that sample_dataset() draws batches from.
data_sampler = iter(dataloader)
def sample_dataset():
    """Return the next ``(x, y)`` batch, restarting the loader when it runs out.

    The iterator lives in the module-level ``data_sampler`` and is rebound
    here when exhausted, hence the ``global`` declaration (without it the
    assignment would make the name local and every call would fail).

    Raises:
        StopIteration: If ``dataloader`` yields no batches at all.
    """
    global data_sampler
    try:
        x, y = next(data_sampler)
    except StopIteration:
        # One pass over the data is finished: start a new one.
        data_sampler = iter(dataloader)
        x, y = next(data_sampler)
    return x, y

# One independent replica of the model per simulated worker.
models = [create_model(i).to(device=args.device) for i in range(args.num_models)]

# Per-replica learning rates. Every replica currently uses the base rate;
# edit this list to give individual replicas their own rate.
lrs = [args.base_lr] * args.num_models
# One param group per replica, so LocalSGD.aggregate() can average across groups.
# Parameters are materialised into a list so each group holds them directly
# rather than a one-shot generator.
param_groups = [
    {"params": list(model.parameters()), "lr": lr}
    for model, lr in zip(models, lrs)
]
optimizer = LocalSGD(param_groups, lr=args.base_lr)
loss_fn = nn.MSELoss().to(device=args.device)

In [None]:
for t in range(args.num_iters):
    # Scheduled iterations average the replicas instead of taking a local step.
    if t in aggregation_idxs:
        optimizer.aggregate()
        continue

    # Otherwise every replica takes one SGD step on its own freshly drawn batch.
    for model in models:
        x, y = sample_dataset()
        x, y = x.to(device=args.device), y.to(device=args.device)
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()