In [None]:
import pickle
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import random
import math
import copy
import io
import time
import matplotlib.pyplot as plt
from collections import OrderedDict


import torch
import torchvision
import torch
import torch.nn as nn
from torch import cuda
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset, Sampler

def set_seed(seed):
    """Seed every RNG used in this notebook and make cuDNN deterministic.

    Seeds PyTorch (CPU and, when available, every CUDA device), NumPy's legacy
    global RNG and Python's ``random`` module, then forces deterministic cuDNN
    kernels and disables cuDNN autotuning.

    Args:
        seed: Integer seed applied to every generator.

    Returns:
        A new ``torch.Generator`` seeded with ``seed``, suitable for e.g. a
        ``DataLoader``'s ``generator=`` argument. Callers that ignore the
        return value are unaffected.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Trade cuDNN speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Previously built and thrown away; hand it back so callers can reuse it.
    generator = torch.Generator()
    generator.manual_seed(seed)
    return generator

# Seed all RNGs once at notebook start-up.
set_seed(0)


# Use the GPU when torch can see one; `device` is passed to the training helpers below.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print(f"Using device: {device}")

# Make the project-local `models_scratch/` and `data/` folders importable.
import sys
sys.path.append('models_scratch/')
sys.path.append('data/')
# NOTE(review): star imports hide where names come from. The model constructors used in
# build_model (VGG, resnet18, ..., LeNet5) are not defined in this notebook and presumably
# come from these modules — confirm, and prefer explicit imports.
from models_scratch import *
from data_utils import *

%matplotlib inline
sns.set(style="whitegrid")

## Double precision or not?
# Set to True to make float64 the default floating-point dtype for tensors created afterwards.
doublePrecision = False


if doublePrecision:
    torch.set_default_dtype(torch.float64)

In [None]:
def get_cifar10_train_loader(batch_size=256, size=32):
    """Return the CIFAR-10 training split with augmentation and normalisation.

    Despite the name, this returns a ``torchvision.datasets.CIFAR10`` dataset,
    not a ``DataLoader``. ``batch_size`` is kept for interface compatibility
    but is not used.

    Args:
        batch_size: Unused.
        size: Target size passed to ``transforms.Resize`` after the 32x32 random crop.

    Returns:
        The training dataset read from ``data/`` (``download=False``: the files
        must already be present there).
    """
    # Bind `transforms` explicitly instead of relying on a star import to provide it.
    from torchvision import transforms

    transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.Resize(size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # NOTE(review): these std values are the widely copied CIFAR-10 constants; some
        # references give ~(0.247, 0.243, 0.261) as the per-pixel std — confirm intended.
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    trainset = torchvision.datasets.CIFAR10(root='data/', train=True, download=False, transform=transform)
    return trainset
def get_cifar10_test_loader(size=32):
    """Return the CIFAR-10 test split, resized and normalised (no augmentation).

    Despite the name, this returns a ``torchvision.datasets.CIFAR10`` dataset,
    not a ``DataLoader``.

    Args:
        size: Target size passed to ``transforms.Resize``.

    Returns:
        The test dataset read from ``data/`` (``download=False``: the files must
        already be present there).
    """
    # Bind `transforms` explicitly instead of relying on a star import to provide it.
    from torchvision import transforms

    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        # Same normalisation constants as the training split.
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    testset = torchvision.datasets.CIFAR10(root='data/', train=False, download=False, transform=transform)
    return testset

# Both helpers return torchvision CIFAR-10 Dataset objects (not DataLoaders).
trainset, testset = get_cifar10_train_loader(), get_cifar10_test_loader()

In [None]:
# Materialise each split as dense tensors by running every sample through its transform once.
# NOTE(review): the train transform includes RandomCrop and RandomHorizontalFlip, so the
# augmentation drawn here is frozen into X_train for the whole experiment — confirm intended.
X_train = torch.stack([img for img, _ in trainset])
y_train = torch.tensor(trainset.targets)

X_test = torch.stack([img for img, _ in testset])
y_test = torch.tensor(testset.targets) 

def filter_cifar10(X, y, minority_class=1, majority_class=9, minority_fraction=0.1, generator=None):
    """Build a shuffled, imbalanced binary subset from two CIFAR-10 classes.

    Every sample of ``majority_class`` is kept; ``minority_class`` is randomly
    subsampled so it makes up (up to integer flooring) ``minority_fraction``
    of the result. If fewer minority samples exist than requested, all are kept.

    Args:
        X: Tensor indexed along dim 0, one row per sample.
        y: 1-D class ids aligned with ``X`` (tensor or sequence).
        minority_class: Class id to subsample; relabelled 0.
        majority_class: Class id kept in full; relabelled 1.
        minority_fraction: Target share of minority samples, in [0, 1).
        generator: Optional ``torch.Generator`` for both random permutations;
            ``None`` uses the global torch RNG.

    Returns:
        ``(X_filtered, y_filtered, S_filtered)``: selected rows, int64 binary
        labels (1 = majority, 0 = minority) and a group indicator that is an
        independent copy of the labels.

    Raises:
        ValueError: if the two classes coincide or ``minority_fraction`` is
            outside [0, 1) (1 would divide by zero below).
    """
    if minority_class == majority_class:
        raise ValueError("minority_class and majority_class must differ")
    if not 0 <= minority_fraction < 1:
        raise ValueError(f"minority_fraction must be in [0, 1), got {minority_fraction}")

    y = torch.as_tensor(y)

    minority_indices = torch.where(y == minority_class)[0]
    majority_indices = torch.where(y == majority_class)[0]

    # Solve n0 / (n0 + n1) = minority_fraction for n0, floored to an int.
    n1 = len(majority_indices)
    n0 = int((minority_fraction / (1 - minority_fraction)) * n1)

    perm_minority = torch.randperm(len(minority_indices), generator=generator)
    selected_minority_indices = minority_indices[perm_minority[:n0]]

    final_indices = torch.cat([selected_minority_indices, majority_indices])
    final_indices = final_indices[torch.randperm(len(final_indices), generator=generator)]

    X_filtered = X[final_indices]
    # Majority -> 1, minority -> 0.
    y_filtered = (y[final_indices] == majority_class).long()
    # In this setup the group label S coincides with the class label.
    S_filtered = y_filtered.clone()

    return X_filtered, y_filtered, S_filtered

# Keep classes 1 (minority, relabelled 0) and 9 (majority, relabelled 1) with a 3% minority share.
X_train, y_train, S_train = filter_cifar10(X_train, y_train, minority_fraction=0.03)
X_test, y_test, S_test = filter_cifar10(X_test, y_test, minority_fraction=0.03)

# Check the group distributions (S=0 is the minority group, S=1 the majority).
print(f"Nombre d'éléments dans S_train=0: {(S_train == 0).sum().item()}")
print(f"Nombre d'éléments dans S_train=1: {(S_train == 1).sum().item()}")
print(f"Nombre d'éléments dans S_test=0: {(S_test == 0).sum().item()}")
print(f"Nombre d'éléments dans S_test=1: {(S_test == 1).sum().item()}")


class DeterministicReverseSampler(Sampler):
    """Yield the indices of ``base_sampler`` in reverse order, frozen after first use.

    The base sampler is drained exactly once, on the first call to either
    ``__iter__`` or ``__len__``. That snapshot is reused for every later epoch,
    so even a random base sampler yields the same reversed order each time.
    """

    def __init__(self, base_sampler):
        # Sampler defining the "forward" order that gets reversed.
        self.base_sampler = base_sampler
        # Lazily-filled snapshot of the base order.
        self._cached_indices = None

    def _snapshot(self):
        # Drain the base sampler on first access only; afterwards reuse that order.
        if self._cached_indices is None:
            self._cached_indices = list(self.base_sampler)
        return self._cached_indices

    def __iter__(self):
        return iter(reversed(self._snapshot()))

    def __len__(self):
        return len(self._snapshot())


In [None]:
def build_model(network, num_classes, input_channels, input_height, input_width, batch_norm = False, device='cuda'):
    """Instantiate the architecture named ``network``, move it to ``device`` and report its size.

    Constructors are looked up from the notebook namespace (presumably provided
    by the ``models_scratch`` star import — confirm). VGG variants receive
    ``batch_norm`` directly; ResNet/DenseNet variants receive
    ``nn.BatchNorm2d`` as ``norm_layer`` when ``batch_norm`` is truthy, else ``None``.
    DenseNet, MobileNet, SqueezeNet and LeNet5 also get the input geometry.

    Raises:
        ValueError: if ``network`` is not one of the supported names.
    """
    norm_layer = nn.BatchNorm2d if batch_norm else None
    geometry = dict(input_channels=input_channels, input_height=input_height, input_width=input_width)

    # Lambdas defer construction (and name lookup) until the chosen entry is called.
    factories = {
        "vgg11": lambda: VGG("VGG11", num_classes=num_classes, batch_norm=batch_norm),
        "vgg19": lambda: VGG("VGG19", num_classes=num_classes, batch_norm=batch_norm),
        "resnet18": lambda: resnet18(norm_layer=norm_layer, num_classes=num_classes),
        "resnet34": lambda: resnet34(norm_layer=norm_layer, num_classes=num_classes),
        "resnet50": lambda: resnet50(norm_layer=norm_layer, num_classes=num_classes),
        "densenet121": lambda: densenet121(norm_layer=norm_layer, num_classes=num_classes, **geometry),
        "mobilenet": lambda: MobileNet(num_classes=num_classes, **geometry),
        "squeezenet": lambda: SqueezeNet(num_classes=num_classes, **geometry),
        "lenet": lambda: LeNet5(num_classes=num_classes, **geometry),
    }
    if network not in factories:
        raise ValueError("Invalid network name.")

    net = factories[network]().to(device)

    param_count = sum(weight.numel() for weight in net.parameters())
    print(f"Total number of parameters in {network}: {param_count:,}")

    return net

In [None]:
# Re-seed every RNG before building the training loader.
set_seed(0)
# Dedicated generator for the DataLoader's shuffling.
g = torch.Generator()
g.manual_seed(0)
# Each sample is (image, group S, label y).
train_dataset = TensorDataset(X_train, S_train, y_train)
# batch_size=len(y_train): one batch per epoch, i.e. full-batch gradient descent;
# shuffling only permutes samples inside that single batch.
trainloader = DataLoader(train_dataset, batch_size=len(y_train), shuffle=True, generator=g)

In [None]:
# `device` must be passed by keyword: positionally it lands in the `batch_norm` slot (a truthy
# string), which silently enables BatchNorm and always sends the model to the default 'cuda'.
# batch_norm=True keeps the BatchNorm architecture the positional call produced; set it to
# False if a BN-free ResNet was intended.
model = build_model("resnet18", 2, 3, 32, 32, batch_norm=True, device=device)
# Snapshot of the initial weights, reused as the starting point of later runs.
theta_init = copy.deepcopy(model.state_dict())
# Notebook-level SGD step size read by the training functions below.
learning_rate = 3e-2

### Step 1: train on the S=1 (majority-class) subset only

In [None]:
def train_S1_phase(model, dataloader, device, epochs=5, lr=None):
    """Train ``model`` with plain SGD on the S == 1 samples of each batch only.

    Args:
        model: ``nn.Module`` mapping an input batch to class logits.
        dataloader: Iterable of ``(X, S, y)`` batches.
        device: Device the selected samples are moved to.
        epochs: Number of passes over ``dataloader``.
        lr: SGD learning rate; ``None`` falls back to the notebook-level
            ``learning_rate``.

    Returns:
        ``(metrics, theta1)``: per-epoch ``"epoch"``, ``"loss"`` and ``"acc"``
        lists (``"acc"`` holds the best accuracy seen so far, not the current
        epoch's) and a deep copy of the final ``state_dict``.
    """
    if lr is None:
        lr = learning_rate  # notebook-level default
    metrics = {"epoch": [], "loss": [], "acc": []}
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0)
    criterion = nn.CrossEntropyLoss(reduction='mean')
    model.train()
    best_acc = 0

    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        count = 0
        for X_batch, S_batch, y_batch in dataloader:
            mask = (S_batch == 1)
            if not mask.any():
                # No S=1 samples here: skip this batch but keep going through the epoch
                # (a `break` would silently drop every remaining batch).
                continue
            X_s1 = X_batch[mask].to(device)
            y_s1 = y_batch[mask].to(device)
            optimizer.zero_grad()
            outputs = model(X_s1)
            loss = criterion(outputs, y_s1)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * X_s1.size(0)
            count += X_s1.size(0)
            _, preds = outputs.max(1)
            correct += preds.eq(y_s1).sum().item()

        avg_loss = running_loss / count if count > 0 else 0
        avg_acc = 100 * correct / count if count > 0 else 0
        best_acc = max(best_acc, avg_acc)
        metrics["epoch"].append(epoch)
        metrics["loss"].append(avg_loss)
        # Running best, not this epoch's accuracy.
        metrics["acc"].append(best_acc)

        if (epoch + 1) % 10 == 0 or epoch == 0:
            print(f"Phase S=1, Epoch {epoch +1}, Loss: {avg_loss:.4f}, Acc: {avg_acc:.2f}%")

    # Deep copy so later training cannot mutate the returned snapshot.
    theta1 = copy.deepcopy(model.state_dict())
    return metrics, theta1

In [None]:
# Step 1: fit the S=1 subset for 150 epochs; theta1 is a copy of the resulting state_dict.
metrics_phase1, theta1 = train_S1_phase(model, trainloader, device, epochs=150)

### Step 2: gradient ascent on the full training set, starting from θ₁

In [None]:
def gradient_ascent_phase(model, dataloader, device, theta1, epochs=10, lr=None):
    """Run gradient *ascent* on the cross-entropy loss, starting from ``theta1``.

    The model is reset to ``theta1``. Each SGD step is then taken along the
    negated gradient, so the loss of every batch is pushed up. Metrics come
    from the logits of the forward pass that precedes each step.

    Args:
        model: ``nn.Module`` producing class logits.
        dataloader: Iterable of ``(X, S, y)`` batches.
        device: Device each batch is moved to.
        theta1: ``state_dict`` copied into the model before training.
        epochs: Number of passes over ``dataloader``.
        lr: SGD step size; ``None`` falls back to the notebook-level ``learning_rate``.

    Returns:
        ``(metrics, theta_unlucky)``: per-epoch global and per-group (S=0 / S=1)
        losses and *running best* accuracies, plus a deep copy of the final
        ``state_dict``.

    Raises:
        ValueError: if a model parameter has no entry in ``theta1``.
    """
    if lr is None:
        lr = learning_rate  # notebook-level default
    metrics = {
        "epoch": [],
        "loss_s0": [],
        "loss_s1": [],
        "loss_global": [],
        "acc_s0": [],
        "acc_s1": [],
        "acc_global": []
    }

    model.load_state_dict(theta1)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0)
    criterion = nn.CrossEntropyLoss(reduction='mean')
    # Sanity check: the model must now hold exactly the weights stored in theta1.
    for name, param in model.named_parameters():
        if name not in theta1:
            raise ValueError(f"{name} not found in theta1")
        saved_param = theta1[name]
        assert param.shape == saved_param.shape, f"Shape mismatch for {name}: {param.shape} vs {saved_param.shape}"
        assert torch.allclose(param.cpu(), saved_param.cpu(), atol=1e-6), f"Value mismatch for {name}"
    model.train()

    best_acc = 0
    best_acc_s0 = 0
    best_acc_s1 = 0

    for epoch in range(epochs):
        total_loss, total_samples = 0.0, 0
        loss_s0_sum, count_s0 = 0.0, 0
        loss_s1_sum, count_s1 = 0.0, 0
        correct_total, correct_s0, correct_s1 = 0, 0, 0

        for X_batch, S_batch, y_batch in dataloader:
            X_batch, S_batch, y_batch = X_batch.to(device), S_batch.to(device), y_batch.to(device)

            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()

            # Flip every gradient so that optimizer.step() *increases* the loss.
            for param in model.parameters():
                if param.grad is not None:
                    param.grad.neg_()
            optimizer.step()

            bsize = y_batch.size(0)
            total_loss += loss.item() * bsize
            total_samples += bsize

            _, preds = outputs.max(1)
            correct_total += preds.eq(y_batch).sum().item()

            # Per-group statistics, weighted by group size within the batch.
            mask_s0 = (S_batch == 0)
            if mask_s0.any():
                n_s0 = mask_s0.sum().item()
                loss_s0 = criterion(outputs[mask_s0], y_batch[mask_s0]).item()
                loss_s0_sum += loss_s0 * n_s0
                count_s0 += n_s0
                correct_s0 += preds[mask_s0].eq(y_batch[mask_s0]).sum().item()

            mask_s1 = (S_batch == 1)
            if mask_s1.any():
                n_s1 = mask_s1.sum().item()
                loss_s1 = criterion(outputs[mask_s1], y_batch[mask_s1]).item()
                loss_s1_sum += loss_s1 * n_s1
                count_s1 += n_s1
                correct_s1 += preds[mask_s1].eq(y_batch[mask_s1]).sum().item()

        avg_loss    = total_loss / total_samples if total_samples > 0 else 0
        avg_loss_s0 = loss_s0_sum / count_s0 if count_s0 > 0 else 0
        avg_loss_s1 = loss_s1_sum / count_s1 if count_s1 > 0 else 0

        acc_total = (correct_total / total_samples) * 100 if total_samples > 0 else 0
        acc_s0    = (correct_s0 / count_s0) * 100 if count_s0 > 0 else 0
        acc_s1    = (correct_s1 / count_s1) * 100 if count_s1 > 0 else 0

        best_acc_s0 = max(best_acc_s0, acc_s0)
        best_acc_s1 = max(best_acc_s1, acc_s1)
        best_acc    = max(best_acc, acc_total)

        metrics["epoch"].append(epoch)
        metrics["loss_s0"].append(avg_loss_s0)
        metrics["loss_s1"].append(avg_loss_s1)
        metrics["loss_global"].append(avg_loss)
        # Accuracies are running bests, not per-epoch values.
        metrics["acc_s0"].append(best_acc_s0)
        metrics["acc_s1"].append(best_acc_s1)
        metrics["acc_global"].append(best_acc)

        if (epoch + 1) % 10 == 0 or epoch == 0:
            print(f"Gradient Ascent, Epoch {epoch+1}, Loss S0: {avg_loss_s0:.4f}, Acc S0: {acc_s0:.2f}%, "
                  f"Loss S1: {avg_loss_s1:.4f}, Acc S1: {acc_s1:.2f}%, Global Loss: {avg_loss:.4f}, Global Acc: {acc_total:.2f}%")

    theta_unlucky = copy.deepcopy(model.state_dict())
    return metrics, theta_unlucky

In [None]:
# Step 2: from theta1, run 100 epochs of gradient ascent on the full training set.
metrics_phase2, theta_unlucky = gradient_ascent_phase(model, trainloader, device, theta1, epochs=100)

### Step 3: full-data training from θ_unlucky, compared with training from θ_init

In [None]:
def train_full_phase(model, dataloader, device, theta, epochs, kappa=99, lr=None):
    """Train on all samples from ``theta`` until the S=0 training accuracy exceeds ``kappa``.

    Args:
        model: ``nn.Module`` producing class logits.
        dataloader: Iterable of ``(X, S, y)`` batches.
        device: Device each batch is moved to.
        theta: ``state_dict`` loaded into the model before training.
        epochs: Maximum number of passes over ``dataloader``.
        kappa: Early-stopping threshold (percent) on the S=0 accuracy of an epoch.
        lr: SGD step size; ``None`` falls back to the notebook-level ``learning_rate``.

    Returns:
        Dict of per-epoch global and per-group (S=0 / S=1) losses and *running best*
        accuracies. Statistics come from each batch's pre-step forward pass.
    """
    if lr is None:
        lr = learning_rate  # notebook-level default
    metrics = {
        "epoch": [],
        "loss_s0": [],
        "loss_s1": [],
        "loss_global": [],
        "acc_s0": [],
        "acc_s1": [],
        "acc_global": []
    }
    model.load_state_dict(theta)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0)
    criterion = nn.CrossEntropyLoss(reduction='mean')
    model.train()

    best_acc = 0
    best_acc_s0 = 0
    best_acc_s1 = 0
    for epoch in range(epochs):
        total_loss, total_samples = 0.0, 0
        loss_s0_sum, count_s0 = 0.0, 0
        loss_s1_sum, count_s1 = 0.0, 0
        correct_total, correct_s0, correct_s1 = 0, 0, 0
        # Iterate the loader that was passed in, not the notebook-global `trainloader`.
        for X_batch, S_batch, y_batch in dataloader:
            X_batch, S_batch, y_batch = X_batch.to(device), S_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()

            bsize = y_batch.size(0)
            total_loss += loss.item() * bsize
            total_samples += bsize

            _, preds = outputs.max(1)
            correct_total += preds.eq(y_batch).sum().item()

            # Per-group statistics, weighted by group size within the batch.
            mask_s0 = (S_batch == 0)
            if mask_s0.any():
                n_s0 = mask_s0.sum().item()
                loss_s0 = criterion(outputs[mask_s0], y_batch[mask_s0]).item()
                loss_s0_sum += loss_s0 * n_s0
                count_s0 += n_s0
                correct_s0 += preds[mask_s0].eq(y_batch[mask_s0]).sum().item()

            mask_s1 = (S_batch == 1)
            if mask_s1.any():
                n_s1 = mask_s1.sum().item()
                loss_s1 = criterion(outputs[mask_s1], y_batch[mask_s1]).item()
                loss_s1_sum += loss_s1 * n_s1
                count_s1 += n_s1
                correct_s1 += preds[mask_s1].eq(y_batch[mask_s1]).sum().item()

        avg_loss    = total_loss / total_samples if total_samples > 0 else 0
        avg_loss_s0 = loss_s0_sum / count_s0 if count_s0 > 0 else 0
        avg_loss_s1 = loss_s1_sum / count_s1 if count_s1 > 0 else 0

        acc_total = (correct_total / total_samples) * 100 if total_samples > 0 else 0
        acc_s0    = (correct_s0 / count_s0) * 100 if count_s0 > 0 else 0
        acc_s1    = (correct_s1 / count_s1) * 100 if count_s1 > 0 else 0

        best_acc_s0 = max(best_acc_s0, acc_s0)
        best_acc_s1 = max(best_acc_s1, acc_s1)
        best_acc = max(best_acc, acc_total)

        metrics["epoch"].append(epoch)
        metrics["loss_s0"].append(avg_loss_s0)
        metrics["loss_s1"].append(avg_loss_s1)
        metrics["loss_global"].append(avg_loss)
        metrics["acc_s0"].append(best_acc_s0)
        metrics["acc_s1"].append(best_acc_s1)
        # Global running best (previously the S=1 value was recorded here by mistake).
        metrics["acc_global"].append(best_acc)

        if acc_s0 > kappa:
            print(f"Phase Full, Epoch {epoch+1}: Acc S0 {acc_s0:.2f}% > {kappa}%, stopping.")
            break

        if (epoch + 1) % 10 == 0 or epoch == 0:
            print(f"Phase Full, Epoch {epoch+1}, Loss S0: {avg_loss_s0:.4f}, Acc S0: {acc_s0:.2f}%, "
              f"Loss S1: {avg_loss_s1:.4f}, Acc S1: {acc_s1:.2f}%, Global Loss: {avg_loss:.4f}, Global Acc: {acc_total:.2f}%")
    return metrics

In [None]:
# Step 3: train on all data from theta_unlucky. epochs=5000 is only an upper bound:
# train_full_phase stops as soon as an epoch's S=0 accuracy exceeds kappa (default 99).
metrics_phase3 = train_full_phase(model, trainloader, device, theta_unlucky, epochs=5000)

In [None]:
# Fresh model for the baseline run from the initial weights. `device` is passed by keyword:
# positionally it would fill the `batch_norm` slot and the model would always go to 'cuda'.
# batch_norm=True keeps the BatchNorm architecture that theta_init was saved from.
# (No optimizer/criterion here: train_full_phase builds its own.)
model = build_model("resnet18", 2, 3, 32, 32, batch_norm=True, device=device)
model.load_state_dict(theta_init)

In [None]:
# Baseline: the same full-data training from theta_init (early stop on S=0 accuracy).
metrics_phase4 = train_full_phase(model, trainloader, device, theta_init, epochs=5000)

In [None]:
def train_full_phase2(model, dataloader, device, theta, epochs, kappa=99, lr=None):
    """Train on all samples from ``theta`` until the *global* training accuracy exceeds ``kappa``.

    Identical to ``train_full_phase`` except for the early-stopping criterion.

    Args:
        model: ``nn.Module`` producing class logits.
        dataloader: Iterable of ``(X, S, y)`` batches.
        device: Device each batch is moved to.
        theta: ``state_dict`` loaded into the model before training.
        epochs: Maximum number of passes over ``dataloader``.
        kappa: Early-stopping threshold (percent) on an epoch's global accuracy.
        lr: SGD step size; ``None`` falls back to the notebook-level ``learning_rate``.

    Returns:
        Dict of per-epoch global and per-group (S=0 / S=1) losses and *running best*
        accuracies. Statistics come from each batch's pre-step forward pass.
    """
    if lr is None:
        lr = learning_rate  # notebook-level default
    metrics = {
        "epoch": [],
        "loss_s0": [],
        "loss_s1": [],
        "loss_global": [],
        "acc_s0": [],
        "acc_s1": [],
        "acc_global": []
    }
    model.load_state_dict(theta)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0)
    criterion = nn.CrossEntropyLoss(reduction='mean')
    model.train()

    best_acc = 0
    best_acc_s0 = 0
    best_acc_s1 = 0
    for epoch in range(epochs):
        total_loss, total_samples = 0.0, 0
        loss_s0_sum, count_s0 = 0.0, 0
        loss_s1_sum, count_s1 = 0.0, 0
        correct_total, correct_s0, correct_s1 = 0, 0, 0
        # Iterate the loader that was passed in, not the notebook-global `trainloader`.
        for X_batch, S_batch, y_batch in dataloader:
            X_batch, S_batch, y_batch = X_batch.to(device), S_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()

            bsize = y_batch.size(0)
            total_loss += loss.item() * bsize
            total_samples += bsize

            _, preds = outputs.max(1)
            correct_total += preds.eq(y_batch).sum().item()

            # Per-group statistics, weighted by group size within the batch.
            mask_s0 = (S_batch == 0)
            if mask_s0.any():
                n_s0 = mask_s0.sum().item()
                loss_s0 = criterion(outputs[mask_s0], y_batch[mask_s0]).item()
                loss_s0_sum += loss_s0 * n_s0
                count_s0 += n_s0
                correct_s0 += preds[mask_s0].eq(y_batch[mask_s0]).sum().item()

            mask_s1 = (S_batch == 1)
            if mask_s1.any():
                n_s1 = mask_s1.sum().item()
                loss_s1 = criterion(outputs[mask_s1], y_batch[mask_s1]).item()
                loss_s1_sum += loss_s1 * n_s1
                count_s1 += n_s1
                correct_s1 += preds[mask_s1].eq(y_batch[mask_s1]).sum().item()

        avg_loss    = total_loss / total_samples if total_samples > 0 else 0
        avg_loss_s0 = loss_s0_sum / count_s0 if count_s0 > 0 else 0
        avg_loss_s1 = loss_s1_sum / count_s1 if count_s1 > 0 else 0

        acc_total = (correct_total / total_samples) * 100 if total_samples > 0 else 0
        acc_s0    = (correct_s0 / count_s0) * 100 if count_s0 > 0 else 0
        acc_s1    = (correct_s1 / count_s1) * 100 if count_s1 > 0 else 0

        best_acc_s0 = max(best_acc_s0, acc_s0)
        best_acc_s1 = max(best_acc_s1, acc_s1)
        best_acc = max(best_acc, acc_total)

        metrics["epoch"].append(epoch)
        metrics["loss_s0"].append(avg_loss_s0)
        metrics["loss_s1"].append(avg_loss_s1)
        metrics["loss_global"].append(avg_loss)
        metrics["acc_s0"].append(best_acc_s0)
        metrics["acc_s1"].append(best_acc_s1)
        # Global running best (previously the S=1 value was recorded here by mistake).
        metrics["acc_global"].append(best_acc)

        if acc_total > kappa:
            print(f"Phase Full, Epoch {epoch+1}: Global Acc {acc_total:.2f}% > {kappa}%, stopping.")
            break

        if (epoch + 1) % 10 == 0 or epoch == 0:
            print(f"Phase Full, Epoch {epoch+1}, Loss S0: {avg_loss_s0:.4f}, Acc S0: {acc_s0:.2f}%, "
              f"Loss S1: {avg_loss_s1:.4f}, Acc S1: {acc_s1:.2f}%, Global Loss: {avg_loss:.4f}, Global Acc: {acc_total:.2f}%")
    return metrics

In [None]:
# Second baseline from theta_init, stopping on *global* training accuracy instead of S=0 accuracy.
# `device` is passed by keyword (positionally it would fill `batch_norm`); batch_norm=True keeps
# the BatchNorm architecture that theta_init was saved from. train_full_phase2 builds its own
# optimizer and loss, so none are created here.
model = build_model("resnet18", 2, 3, 32, 32, batch_norm=True, device=device)
model.load_state_dict(theta_init)
metrics_phase5 = train_full_phase2(model, trainloader, device, theta_init, epochs=5000)

In [None]:
def plot_metrics(metrics_phase3, metrics_phase4, metrics_phase5, skip=5,
                 save_path="results/CIFAR-2/trajectory.pdf"):
    """Plot per-group and global training loss of the three full-data runs side by side.

    Each ``metrics_*`` dict needs equal-length ``"epoch"``, ``"loss_s0"``,
    ``"loss_s1"`` and ``"loss_global"`` lists, as returned by the
    ``train_full_phase*`` functions. The first ``skip`` epochs are dropped, and
    the x-axis shows the true epoch index of each remaining point. Losses use a
    log scale. The figure gets one shared legend below the panels, is saved to
    ``save_path`` (parent directories are created; pass ``None`` to skip saving)
    and is then shown.
    """
    import os
    from collections import OrderedDict

    import matplotlib.pyplot as plt
    import numpy as np

    color_dict = {
        r"$L$": "green",
        r"$L_0$": "blue",
        r"$L_1$": "orange"
    }
    # (legend label, metrics key) in drawing order; this order also fixes the legend order.
    curves = [(r"$L_1$", "loss_s1"), (r"$L_0$", "loss_s0"), (r"$L$", "loss_global")]
    # Left: phase 3, middle: phase 4, right: phase 5; only the left panel gets a y-label.
    panels = [(metrics_phase3, "Loss"), (metrics_phase4, ""), (metrics_phase5, "")]

    fig, axes = plt.subplots(1, 3, figsize=(15, 3.75))
    for ax, (metrics, ylabel) in zip(axes, panels):
        # Start at `skip` so tick labels match the epochs that are actually plotted.
        epochs = np.arange(skip, len(metrics["epoch"]))
        for label, key in curves:
            ax.plot(epochs, metrics[key][skip:], label=label, color=color_dict[label])
        ax.set_xlabel("Epoch")
        ax.set_ylabel(ylabel)
        ax.set_yscale('log')

    # Every panel uses the same labels, so merge them into one de-duplicated legend.
    unique = OrderedDict()
    for ax in axes:
        for handle, label in zip(*ax.get_legend_handles_labels()):
            unique.setdefault(label, handle)

    fig.legend(unique.values(), unique.keys(), loc="lower center",
               ncol=len(unique), bbox_to_anchor=(0.5, -0.05))

    fig.tight_layout(rect=[0, 0.07, 1, 1])
    if save_path:
        directory = os.path.dirname(save_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        fig.savefig(save_path, bbox_inches="tight")
    plt.show()

In [None]:
# Left: from theta_unlucky; middle: from theta_init (stop on S=0 acc); right: from theta_init (stop on global acc).
plot_metrics(metrics_phase3, metrics_phase4, metrics_phase5)