Implémentation du papier de recherche "(S)GD over Diagonal Linear Networks:
Implicit Bias, Large Stepsizes and Edge of Stability"

In [16]:
import torch

import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets, transforms
import torchvision

import numpy as np
import matplotlib.pyplot as plt

# Données synthétiques

On génère des données synthétiques dont le vecteur cible $\beta_{\mathrm{sparse}}$ (noté `w_star` dans le code) est non nul sur $k$ composantes et nul sur toutes les autres.

In [20]:
def generate_data(n, d, k, sigma=0.0, generator=None):
    """Generate a synthetic sparse linear-regression problem.

    The design matrix has i.i.d. standard normal entries and the targets are
    ``y = X @ w_star + sigma * noise``, where ``w_star`` is zero everywhere
    except on ``k`` randomly chosen coordinates that hold standard normal
    draws.

    Args:
        n: Number of samples.
        d: Dimension of the data.
        k: Sparsity, i.e. number of non-zero components in ``w_star``.
            Must satisfy ``0 <= k <= d``.
        sigma: Scale of the additive standard normal label noise.
        generator: Optional ``torch.Generator`` used for every random draw.
            ``None`` (the default) uses the global torch RNG.

    Returns:
        A tuple ``(X, y, w_star)`` with shapes ``(n, d)``, ``(n, 1)`` and
        ``(d, 1)``.

    Raises:
        ValueError: If ``k`` is negative or larger than ``d``.
    """
    # Without this check an invalid k only surfaces as a shape-mismatch
    # error in the indexed assignment below.
    if not 0 <= k <= d:
        raise ValueError(f"k must satisfy 0 <= k <= d, got k={k} and d={d}")

    X = torch.randn(n, d, generator=generator)

    # Pick the support of the ground-truth vector uniformly at random.
    w_star = torch.zeros(d, 1)
    support = torch.randperm(d, generator=generator)[:k]
    w_star[support] = torch.randn(k, 1, generator=generator)

    # The noise is drawn even when sigma == 0, so the amount of randomness
    # consumed does not depend on sigma.
    noise = torch.randn(n, 1, generator=generator)
    y = X @ w_star + sigma * noise

    return X, y, w_star

# Standard settings from the paper (e.g. d=100, n=40, k=3)
n = 40
d = 100
k = 3
X, y, w_star = generate_data(n=n, d=d, k=k)

# Wrap the synthetic samples so they can be iterated one at a time, in a
# freshly shuffled order on every pass.
train_dataset = torch.utils.data.TensorDataset(X, y)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=1,
    shuffle=True,
)

# Initialisation de la classe du Diagonal Linear Net

In [21]:
class DLN(nn.Module):
    """Two-layer linear network with a diagonal initialization.

    The forward pass is ``fc2(fc1(x))`` with no activation in between. Right
    after construction every weight and bias is zero except the main
    diagonal of ``fc1.weight``.

    NOTE(review): both layers are dense ``nn.Linear`` modules with biases,
    and nothing in this class keeps ``fc1.weight`` diagonal once an
    optimizer updates it -- confirm this is the intended parameterization.

    Args:
        input_dim: Number of input features of ``fc1``.
        output_dim: Number of output features of ``fc2``.
        hidden_dim: Width of the intermediate layer.
        alpha: Initialization scale. With ``uniform=True`` every diagonal
            entry of ``fc1.weight`` equals ``alpha``; otherwise the entries
            are ``2 * alpha`` times samples from ``torch.rand``.
        uniform: Selects the constant (``True``) or random (``False``)
            diagonal initialization.
    """

    def __init__(self, input_dim, output_dim, hidden_dim, alpha=1.0, uniform=True):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

        self.alpha = alpha
        self.uniform = uniform

        self._initialize_weights()

    def _initialize_weights(self):
        """Zero all parameters, then fill the main diagonal of ``fc1.weight``."""
        for layer in (self.fc1, self.fc2):
            nn.init.zeros_(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

        weight = self.fc1.weight
        # fc1.weight may be rectangular; its main diagonal has this length.
        diag_len = min(self.fc1.in_features, self.fc1.out_features)

        # Build the values with the weight's own dtype/device so that an
        # integer alpha (or a non-default dtype) cannot cause a dtype
        # mismatch when they are written into the parameter.
        if self.uniform:
            diag_vals = torch.full(
                (diag_len,), self.alpha, dtype=weight.dtype, device=weight.device
            )
        else:
            diag_vals = (
                torch.rand(diag_len, dtype=weight.dtype, device=weight.device)
                * 2
                * self.alpha
            )

        with torch.no_grad():
            # diagonal() is a view, so copy_ writes straight into the parameter.
            weight.diagonal().copy_(diag_vals)

    def forward(self, x):
        """Apply ``fc1`` then ``fc2`` with no non-linearity."""
        return self.fc2(self.fc1(x))


# Fonction d'entrainement du DLN

In [22]:
def train_dln(model, X, y, lr, epochs, mode='gd', batch_size=None):
    """Train ``model`` on ``(X, y)`` with ``optim.SGD`` steps on the MSE loss.

    Args:
        model: Module optimized in place; it is called as ``model(batch_X)``.
        X: Input tensor whose first dimension indexes the samples.
        y: Target tensor aligned with ``X`` along the first dimension.
        lr: Step size passed to ``optim.SGD``.
        epochs: Number of passes over the data.
        mode: ``'gd'`` takes a single full-batch step per epoch. Any other
            value trains on shuffled mini-batches.
        batch_size: Mini-batch size for the non-``'gd'`` modes; defaults to
            1 when omitted. Ignored when ``mode == 'gd'``.

    Returns:
        A list with one float per epoch: the sample-weighted mean of the
        batch losses, each evaluated before the corresponding update.
    """
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)

    if mode == 'gd':
        # Full-batch gradient descent: one "batch" holding every sample.
        data_loader = [(X, y)]
    else:
        # DataLoader(batch_size=None) turns automatic batching off and yields
        # single samples without a batch dimension, which would make
        # batch_X.size(0) below the feature dimension. Apply the default of
        # 1 to every mini-batch mode, not only to mode == 'sgd'.
        if batch_size is None:
            batch_size = 1
        dataset = torch.utils.data.TensorDataset(X, y)
        data_loader = torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=True
        )

    n_samples = X.size(0)
    losses = []

    for _ in range(epochs):
        epoch_loss = 0.0
        for batch_X, batch_y in data_loader:
            optimizer.zero_grad()
            loss = criterion(model(batch_X), batch_y)
            loss.backward()
            optimizer.step()
            # Weight by batch size so that a smaller final batch does not
            # skew the epoch average.
            epoch_loss += loss.item() * batch_X.size(0)

        losses.append(epoch_loss / n_samples)

    return losses