In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt

In [2]:
class MaskedCrossEntropyLoss(nn.Module):
    """Cross-entropy evaluated on logits whose class components are randomly zeroed.

    On every forward call one Bernoulli mask is drawn with an entry per class
    (dimension 1 of ``logits``); each entry is 1 with probability ``1 - p``.
    The same mask is applied to every sample of the batch by element-wise
    multiplication, so a dropped class has its logit set to 0 (it is not
    removed from the softmax). The mean per-sample loss is then divided by
    ``1 - p`` unless ``p`` equals 1.

    Args:
        p: Probability of zeroing each class component. Values outside
            [0, 1] are clamped into that interval.
    """

    def __init__(self, p):
        super().__init__()
        # Clamp the drop probability into [0, 1].
        self.p = max(0, min(1, p))
        # Keep per-sample losses; the reduction is done explicitly in forward().
        self.ce_loss = nn.CrossEntropyLoss(reduction='none')

    def forward(self, logits, target):
        """Return the rescaled mean cross-entropy of the masked logits.

        Args:
            logits: Model outputs; dimension 1 is the class dimension.
            target: Class targets accepted by ``nn.CrossEntropyLoss``.

        Returns:
            A scalar tensor: the batch-mean loss, divided by ``1 - p`` when
            ``p < 1`` and left unscaled when ``p == 1``.
        """
        keep_prob = 1 - self.p
        keep_probs = torch.full(
            (logits.size(1),),
            keep_prob,
            device=logits.device,
            dtype=torch.float,
        )
        # One 0/1 entry per class, shared by the whole batch.
        class_mask = torch.bernoulli(keep_probs)
        # Zero out the dropped components.
        per_sample = self.ce_loss(logits * class_mask, target)
        mean_loss = per_sample.mean()

        if self.p < 1:
            return mean_loss / keep_prob
        # With p == 1 every component is dropped; dividing by zero is avoided.
        return mean_loss


In [3]:
class LogisticRegression(nn.Module):
    """Multinomial logistic regression: flatten the input, then one linear layer.

    The forward pass returns the raw linear outputs (logits); no softmax is
    applied here.

    Args:
        input_dim: Number of features per sample after flattening. The default
            matches 32x32 images with 3 channels.
        num_classes: Number of output scores per sample.
    """

    def __init__(self, input_dim=32*32*3, num_classes=100):
        super().__init__()
        # Collapses every dimension after the batch dimension into one.
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        """Return the class logits for the batch ``x``."""
        features = self.flatten(x)
        logits = self.linear(features)
        return logits


In [4]:
def run_experiment(n, p, b, num_epochs=15, num_runs=1):
    """Train logistic regression on CIFAR-100 with the masked loss and report the mean loss.

    The CIFAR-100 training split is loaded from ``./data`` (downloaded if
    missing) and iterated in shuffled mini-batches of size ``n``. For each run
    a fresh ``LogisticRegression`` model is trained with plain SGD (lr=0.01)
    against ``MaskedCrossEntropyLoss(p)``. Runs on CUDA when available,
    otherwise on CPU.

    Args:
        n: Mini-batch size passed to the ``DataLoader``.
        p: Drop probability handed to ``MaskedCrossEntropyLoss``.
        b: Not used inside this function; accepted so existing callers that
            pass it keep working.
        num_epochs: Number of passes over the training set per run.
        num_runs: Number of independent training runs to average over.

    Returns:
        dict with keys ``"n"`` and ``"p"`` (the arguments as given) and
        ``"avg_loss"``: the mean over runs of the mean over epochs of the
        per-epoch mean logged batch loss.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    transform = transforms.ToTensor()
    trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
    trainloader = DataLoader(trainset, batch_size=n, shuffle=True)

    run_avg_losses = []
    for run in range(num_runs):
        print(f"Run {run + 1}/{num_runs} for n={n}, p={p}")
        model = LogisticRegression().to(device)
        criterion = MaskedCrossEntropyLoss(p=p).to(device)
        optimizer = optim.SGD(model.parameters(), lr=0.01)

        # The criterion divides the mean loss by (1 - p) only when its own
        # (clamped) p is below 1. Undo exactly that scaling for logging, so
        # that p == 1 does not log zeros and an out-of-range p is not
        # mis-scaled. For 0 <= p < 1 this equals the plain (1 - p) factor.
        unscale = (1 - criterion.p) if criterion.p < 1 else 1.0

        epoch_losses = []
        for epoch in range(num_epochs):
            batch_losses = []
            model.train()
            for data, target in trainloader:
                data, target = data.to(device), target.to(device)
                optimizer.zero_grad()
                output = model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()
                # Log the loss without the criterion's 1/(1 - p) factor.
                batch_losses.append(loss.item() * unscale)
            avg_loss = np.mean(batch_losses)
            epoch_losses.append(avg_loss)
            print(f"Epoch {epoch + 1}/{num_epochs}, avg_loss={avg_loss:.2e}")

        run_avg_losses.append(np.mean(epoch_losses))

    avg_loss = np.mean(run_avg_losses)
    print(f"n={n}, p={p}, avg_loss={avg_loss:.2e}")
    return {"n": n, "p": p, "avg_loss": avg_loss}

In [6]:
# Sweep over batch sizes. For each n the drop probability is chosen so that
# n * 100 * (1 - p) == b.
b = 6400
n_values = [32, 64, 72, 80, 94, 108, 128, 192, 256, 512]
results = []

for n in n_values:
    p = 1 - b / (n * 100)
    if 0 <= p <= 1:
        print(f"Running experiment with n={n}, p={p}")
        result = run_experiment(n, p, b)
        results.append(result)
    else:
        # Batch sizes for which p leaves [0, 1] are not run.
        print(f"Skipping n={n}, p={p} (invalid p)")

# One scatter point per completed experiment, each with its own legend entry.
fig, ax = plt.subplots(figsize=(10, 6))
for res in results:
    ax.scatter(
        res['n'],
        res['avg_loss'],
        s=100,
        alpha=0.6,
        c='teal',
        label=f"n={res['n']}, p={res['p']:.2f}, loss={res['avg_loss']:.2e}",
    )
ax.set_xscale('log')
ax.set_xlabel('Batch Size (n)')
ax.set_ylabel('Average Loss')
ax.set_title('Average Loss vs Batch Size (n)')
ax.grid(True, which="both", ls="--")
ax.legend()
plt.show()

Skipping n=32, p=-1.0 (invalid p)
Running experiment with n=64, p=0.0
Run 1/1 for n=64, p=0.0
Epoch 1/15, avg_loss=4.30e+00
Epoch 2/15, avg_loss=4.02e+00
Epoch 3/15, avg_loss=3.91e+00
Epoch 4/15, avg_loss=3.84e+00
Epoch 5/15, avg_loss=3.79e+00
Epoch 6/15, avg_loss=3.76e+00
Epoch 7/15, avg_loss=3.73e+00
Epoch 8/15, avg_loss=3.70e+00
Epoch 9/15, avg_loss=3.68e+00
Epoch 10/15, avg_loss=3.66e+00
Epoch 11/15, avg_loss=3.65e+00
Epoch 12/15, avg_loss=3.63e+00
Epoch 13/15, avg_loss=3.62e+00
Epoch 14/15, avg_loss=3.61e+00
Epoch 15/15, avg_loss=3.60e+00
n=64, p=0.0, avg_loss=3.77e+00
Running experiment with n=72, p=0.11111111111111116
Run 1/1 for n=72, p=0.11111111111111116
Epoch 1/15, avg_loss=4.35e+00
Epoch 2/15, avg_loss=4.11e+00
Epoch 3/15, avg_loss=4.02e+00
Epoch 4/15, avg_loss=3.96e+00
Epoch 5/15, avg_loss=3.91e+00
Epoch 6/15, avg_loss=3.88e+00
Epoch 7/15, avg_loss=3.86e+00
Epoch 8/15, avg_loss=3.83e+00
Epoch 9/15, avg_loss=3.81e+00
Epoch 10/15, avg_loss=3.80e+00
Epoch 11/15, avg_loss=3.78

KeyboardInterrupt: 