In [None]:
# Deprecated
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms

datasets = torchvision.datasets
import copy

# ---------------------------
# Global settings
# ---------------------------
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 64
epochs = 5  # kept small for the demo; real training usually needs more epochs

# ---------------------------
# MNIST dataset
# ---------------------------
# Convert images to tensors, then normalize with mean 0.1307 and std 0.3081.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# The training split may be downloaded into ./data; the test split is opened
# with download=False, so it expects the files to be present already.
train_dataset = datasets.MNIST("./data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST("./data", train=False, download=False, transform=transform)


# Keep only the samples of a dataset that belong to the requested classes.
def filter_dataset(dataset, classes):
    """Select the samples of ``dataset`` whose label is listed in ``classes``.

    The images are read straight from ``dataset.data`` (so ``dataset.transform``
    is not applied) and scaled to ``[0, 1]`` by dividing by 255.  Labels are
    returned unchanged, i.e. they are NOT remapped to ``0..len(classes)-1``.

    Args:
        dataset: Object exposing tensor attributes ``data`` (N, H, W) and
            ``targets`` (N,).
        classes: Iterable of admissible class labels.

    Returns:
        Tuple ``(data, targets)``: ``data`` is a float tensor of shape
        (K, 1, H, W) and ``targets`` an int64 tensor of shape (K,).  Both are
        empty (K == 0) when no sample matches.
    """
    targets = torch.as_tensor(dataset.targets)
    # Build the selection mask with one vectorized comparison per class
    # instead of a Python-level loop over every sample.
    mask = torch.zeros_like(targets, dtype=torch.bool)
    for c in classes:
        mask |= targets == c
    # Boolean indexing also copes with an empty selection, which would make
    # torch.stack() on an empty list fail.
    data = dataset.data[mask].unsqueeze(1).float() / 255.0
    return data, targets[mask].long()


# Build a shuffling mini-batch loader over in-memory tensors.
def make_loader(data, targets):
    """Wrap ``data`` and ``targets`` in a shuffled ``DataLoader``.

    The batch size is taken from the module-level ``batch_size``.
    """
    return torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(data, targets),
        shuffle=True,
        batch_size=batch_size,
    )


# ---------------------------
# Small CNN classifier: two conv/pool stages followed by two linear layers
# ---------------------------
class SimpleCNN(nn.Module):
    """Compact convolutional classifier.

    The flatten size ``32 * 5 * 5`` corresponds to 1x28x28 inputs: every
    unpadded 3x3 convolution trims two pixels and every 2x2 max-pool halves
    the side length (28 -> 26 -> 13 -> 11 -> 5).

    Args:
        num_classes: Number of logits produced by the last linear layer.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1)
        self.fc1 = nn.Linear(in_features=32 * 5 * 5, out_features=50)
        self.fc2 = nn.Linear(in_features=50, out_features=num_classes)

    def forward(self, x):
        """Return the raw class logits for a batch of images."""
        pooled = torch.max_pool2d(torch.relu(self.conv1(x)), 2)
        pooled = torch.max_pool2d(torch.relu(self.conv2(pooled)), 2)
        flat = pooled.view(-1, 32 * 5 * 5)
        hidden = torch.relu(self.fc1(flat))
        return self.fc2(hidden)


# ---------------------------
# Training and evaluation helpers
# ---------------------------
def train_model(model, loader, optimizer, criterion):
    """Run one optimization pass over ``loader``.

    Puts ``model`` in training mode and, for every ``(inputs, labels)`` batch,
    moves the tensors to the module-level ``device`` and takes one optimizer
    step on ``criterion(model(inputs), labels)``.  Returns nothing.
    """
    model.train()
    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()


def test_model(model, loader):
    """Return the top-1 accuracy of ``model`` over ``loader``.

    The model is switched to eval mode and evaluated without gradient
    tracking; batches are moved to the module-level ``device``.  An empty
    loader yields ``0``.
    """
    model.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            top1 = model(inputs).argmax(dim=1, keepdim=True)
            hits += top1.eq(labels.view_as(top1)).sum().item()
            seen += labels.size(0)
    if seen > 0:
        return hits / seen
    return 0


# ---------------------------
# Step 1: train model A on classes {0,1} (a two-class problem).
# The labels need no remapping here (0->0, 1->1); all other digits are dropped.
# ---------------------------
train_data_A, train_targets_A = filter_dataset(train_dataset, [0, 1])
test_data_A, test_targets_A = filter_dataset(test_dataset, [0, 1])

train_loader_A = make_loader(train_data_A, train_targets_A)
test_loader_A = make_loader(test_data_A, test_targets_A)

model_A = SimpleCNN(num_classes=2).to(device)
optimizer_A = optim.Adam(model_A.parameters(), lr=0.001)
# Shared by every training loop in this cell.
criterion = nn.CrossEntropyLoss()

for ep in range(epochs):
    train_model(model_A, train_loader_A, optimizer_A, criterion)
acc_A = test_model(model_A, test_loader_A)
print("Accuracy A (0 vs 1):", acc_A)

# ---------------------------
# Step 2: build model B from a copy of A and fine-tune it on classes {0,1,2}.
# The classifier head now has num_classes=3.
# ---------------------------
model_B = SimpleCNN(num_classes=3).to(device)
# Initialize B with A's weights in the layers whose shapes match:
with torch.no_grad():
    model_B.conv1.weight.copy_(model_A.conv1.weight)
    model_B.conv1.bias.copy_(model_A.conv1.bias)
    model_B.conv2.weight.copy_(model_A.conv2.weight)
    model_B.conv2.bias.copy_(model_A.conv2.bias)
    model_B.fc1.weight.copy_(model_A.fc1.weight)
    model_B.fc1.bias.copy_(model_A.fc1.bias)
    # fc2 has a different output size, so it is re-initialized from scratch:
    nn.init.xavier_normal_(model_B.fc2.weight)
    nn.init.zeros_(model_B.fc2.bias)

# Labels 0, 1, 2 already index the three logits directly.
train_data_B, train_targets_B = filter_dataset(train_dataset, [0, 1, 2])
test_data_B, test_targets_B = filter_dataset(test_dataset, [0, 1, 2])
train_loader_B = make_loader(train_data_B, train_targets_B)
test_loader_B = make_loader(test_data_B, test_targets_B)

optimizer_B = optim.Adam(model_B.parameters(), lr=0.001)

for ep in range(epochs):
    train_model(model_B, train_loader_B, optimizer_B, criterion)
acc_B = test_model(model_B, test_loader_B)
print("Accuracy B (0,1,2):", acc_B)

# ---------------------------
# Step 3: build model V from a copy of A and fine-tune it on classes {0,1,3}.
# ---------------------------
model_V = SimpleCNN(num_classes=3).to(device)
# Initialize V with A's weights in the shape-compatible layers; the 3-way
# head fc2 is re-initialized from scratch.
with torch.no_grad():
    model_V.conv1.weight.copy_(model_A.conv1.weight)
    model_V.conv1.bias.copy_(model_A.conv1.bias)
    model_V.conv2.weight.copy_(model_A.conv2.weight)
    model_V.conv2.bias.copy_(model_A.conv2.bias)
    model_V.fc1.weight.copy_(model_A.fc1.weight)
    model_V.fc1.bias.copy_(model_A.fc1.bias)
    nn.init.xavier_normal_(model_V.fc2.weight)
    nn.init.zeros_(model_V.fc2.bias)

train_data_V, train_targets_V = filter_dataset(train_dataset, [0, 1, 3])
test_data_V, test_targets_V = filter_dataset(test_dataset, [0, 1, 3])
# filter_dataset keeps the raw digit labels, so digit 3 would be passed to
# CrossEntropyLoss as target 3 although model_V only has logits 0..2, which
# is out of range.  Map digit 3 onto the third logit (index 2).
train_targets_V[train_targets_V == 3] = 2
test_targets_V[test_targets_V == 3] = 2
train_loader_V = make_loader(train_data_V, train_targets_V)
test_loader_V = make_loader(test_data_V, test_targets_V)

optimizer_V = optim.Adam(model_V.parameters(), lr=0.001)

for ep in range(epochs):
    train_model(model_V, train_loader_V, optimizer_V, criterion)
acc_V = test_model(model_V, test_loader_V)
print("Accuracy V (0,1,3):", acc_V)

# ---------------------------
# We now have three models: A, B and V.
# A -- 0 vs 1
# B -- 0,1 vs 2
# V -- 0,1 vs 3
#
# We want to combine their knowledge for the classes {0,1,2,3}:
# A' = A + αΔK + βΔS, where ΔK = B - A and ΔS = V - A.
#
# To keep things simple we only build a model A' whose weights are
# parameterized as a linear combination of the corresponding layers.
#
# In practice this is harder, because the size of fc2 changed between models.
# For this example we therefore combine only the conv1, conv2 and fc1 weights.
#
# A real application would have to reconcile the architectures more carefully.
# What follows is a simplified illustration of the idea.
# ---------------------------


def get_params_vector(model):
    """Return cloned copies of every parameter tensor of ``model``.

    The list follows the order of ``model.parameters()``; the clones are
    taken from ``.data`` and are therefore detached from the model.
    """
    snapshot = []
    for param in model.parameters():
        snapshot.append(param.data.clone())
    return snapshot


# Full parameter snapshots of the three trained models.
# NOTE(review): these three lists are not referenced again in this cell.
params_A = get_params_vector(model_A)
params_B = get_params_vector(model_B)
params_V = get_params_vector(model_V)

# The deltas are computed only for the matching layers; the last layer fc2 is
# excluded because its output size differs between the models.
# ΔK and ΔS are therefore applied to the shared layers conv1, conv2 and fc1 only,
# while the last layer is tuned separately (or retrained on a small sample).


def extract_base_layers(model):
    """Return the weight and bias parameters of conv1, conv2 and fc1.

    The order is weight then bias for each layer; the returned objects are
    the live parameters of ``model``, not copies.
    """
    shared = (model.conv1, model.conv2, model.fc1)
    return [tensor for layer in shared for tensor in (layer.weight, layer.bias)]


def get_layer_data(layers):
    """Return a list with a cloned ``.data`` tensor for every entry of ``layers``."""
    return list(map(lambda tensor: tensor.data.clone(), layers))


# Snapshots of the shared layers (conv1, conv2, fc1) of each trained model.
base_A = get_layer_data(extract_base_layers(model_A))
base_B = get_layer_data(extract_base_layers(model_B))
base_V = get_layer_data(extract_base_layers(model_V))

# Per-layer weight differences relative to A.
ΔK = [b - a for b, a in zip(base_B, base_A)]
ΔS = [v - a for v, a in zip(base_V, base_A)]

# Next, make α and β trainable and try to find them by minimizing the error
# on a small sample drawn from the classes {0,1,2,3}.

# Collect the data for the final test and for the adaptation step.
final_classes = [0, 1, 2, 3]
train_data_F, train_targets_F = filter_dataset(train_dataset, final_classes)
test_data_F, test_targets_F = filter_dataset(test_dataset, final_classes)

# Keep only a small subset of train_data_F (the first subset_size samples)
# for fitting α and β:
subset_size = 1000
train_data_F = train_data_F[:subset_size]
train_targets_F = train_targets_F[:subset_size]

train_loader_F = make_loader(train_data_F, train_targets_F)
test_loader_F = make_loader(test_data_F, test_targets_F)

# Define α and β as PyTorch parameters, both starting at zero.
alpha = torch.nn.Parameter(torch.zeros(1, requires_grad=True, device=device))
beta = torch.nn.Parameter(torch.zeros(1, requires_grad=True, device=device))

# Model A' is a copy of A whose conv1, conv2 and fc1 weights are obtained as A + αΔK + βΔS.
# The last layer fc2 is re-initialized for 4 classes and trained together with α and β.

model_A_prime = SimpleCNN(num_classes=4).to(device)
# Initialize A' with A's weights in the shared layers.
with torch.no_grad():
    model_A_prime.conv1.weight.copy_(model_A.conv1.weight)
    model_A_prime.conv1.bias.copy_(model_A.conv1.bias)
    model_A_prime.conv2.weight.copy_(model_A.conv2.weight)
    model_A_prime.conv2.bias.copy_(model_A.conv2.bias)
    model_A_prime.fc1.weight.copy_(model_A.fc1.weight)
    model_A_prime.fc1.bias.copy_(model_A.fc1.bias)
    # fc2 is initialized from scratch
    nn.init.xavier_normal_(model_A_prime.fc2.weight)
    nn.init.zeros_(model_A_prime.fc2.bias)

# One optimizer over all parameters of A' plus the two mixing coefficients.
optimizer_F = optim.Adam(list(model_A_prime.parameters()) + [alpha, beta], lr=0.001)


def apply_deltas_to_base(model, alpha, beta, base_A, ΔK, ΔS):
    """Write ``A + alpha*ΔK + beta*ΔS`` into the shared layers of ``model``.

    The copy happens in place under ``torch.no_grad()``, so it only
    materializes the current combination inside the module; it does not let
    gradients reach ``alpha`` or ``beta``.

    Args:
        model: Network whose conv1/conv2/fc1 parameters are overwritten.
        alpha, beta: Mixing coefficients for ``ΔK`` and ``ΔS``.
        base_A, ΔK, ΔS: Per-layer tensors in ``extract_base_layers`` order.
    """
    with torch.no_grad():
        layers = extract_base_layers(model)
        for l, a_, dk, ds in zip(layers, base_A, ΔK, ΔS):
            l.copy_(a_ + alpha * dk + beta * ds)


def _forward_with_deltas(model, x, alpha, beta, base_A, ΔK, ΔS):
    """Differentiable forward pass using ``A + alpha*ΔK + beta*ΔS`` as shared weights.

    Mirrors ``SimpleCNN.forward`` but computes the conv1/conv2/fc1 weights as
    part of the autograd graph, so the loss gradient reaches ``alpha`` and
    ``beta``.  The head ``model.fc2`` is used as is and stays trainable.
    """
    w_c1, b_c1, w_c2, b_c2, w_fc1, b_fc1 = [
        a_ + alpha * dk + beta * ds for a_, dk, ds in zip(base_A, ΔK, ΔS)
    ]
    x = torch.relu(nn.functional.conv2d(x, w_c1, b_c1))
    x = torch.max_pool2d(x, 2)
    x = torch.relu(nn.functional.conv2d(x, w_c2, b_c2))
    x = torch.max_pool2d(x, 2)
    x = x.view(-1, w_fc1.shape[1])
    x = torch.relu(nn.functional.linear(x, w_fc1, b_fc1))
    return model.fc2(x)


for ep in range(epochs):
    model_A_prime.train()
    for data, target in train_loader_F:
        data, target = data.to(device), target.to(device)

        optimizer_F.zero_grad()
        # Copying the combination into the module under no_grad would cut
        # alpha and beta out of the graph (they would stay at 0), so the
        # combined weights are built inside the forward pass instead.
        output = _forward_with_deltas(model_A_prime, data, alpha, beta, base_A, ΔK, ΔS)
        loss = criterion(output, target)
        loss.backward()
        optimizer_F.step()

# Evaluate: materialize the learned combination in the module first.
apply_deltas_to_base(model_A_prime, alpha, beta, base_A, ΔK, ΔS)
acc_A_prime = test_model(model_A_prime, test_loader_F)
print("Accuracy A' (0,1,2,3):", acc_A_prime)
print("Found alpha:", alpha.item(), "beta:", beta.item())

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import copy

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 128
epochs = 5

# ---------------------------
# MNIST dataset
# ---------------------------
# Convert images to tensors, then normalize with mean 0.1307 and std 0.3081.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST("./data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST("./data", train=False, transform=transform)


def filter_dataset(dataset, classes):
    """Select the samples of the given classes and relabel them ``0..K-1``.

    Images are read straight from ``dataset.data`` (``dataset.transform`` is
    not applied) and scaled to ``[0, 1]`` by dividing by 255.

    Args:
        dataset: Object exposing tensor attributes ``data`` (N, H, W) and
            ``targets`` (N,).
        classes: Iterable of class labels to keep; duplicates are ignored.

    Returns:
        Tuple ``(data, new_targets, class_to_new)``: a float tensor of shape
        (M, 1, H, W), an int64 tensor of remapped labels of shape (M,), and
        the dict mapping each original label to its new index (in ascending
        label order).
    """
    # Deduplicate before enumerating: a repeated class would otherwise consume
    # an index and leave gaps in the new label range.
    class_to_new = {c: i for i, c in enumerate(sorted(set(classes)))}
    targets = torch.as_tensor(dataset.targets)
    mask = torch.zeros_like(targets, dtype=torch.bool)
    new_targets = torch.zeros_like(targets, dtype=torch.long)
    # One vectorized comparison per class replaces the per-sample Python loops.
    for c, i in class_to_new.items():
        hit = targets == c
        mask |= hit
        new_targets[hit] = i
    data = dataset.data[mask].unsqueeze(1).float() / 255.0
    return data, new_targets[mask], class_to_new


def make_loader(data, targets):
    """Return a shuffling ``DataLoader`` over ``(data, targets)`` pairs.

    Batches have the module-level ``batch_size``.
    """
    tensor_ds = torch.utils.data.TensorDataset(data, targets)
    loader = torch.utils.data.DataLoader(tensor_ds, shuffle=True, batch_size=batch_size)
    return loader


# ---------------------------
# Autoencoder for unsupervised pre-training.
# Its encoder half serves as the backbone.
# ---------------------------
class Autoencoder(nn.Module):
    """Convolutional autoencoder with a linear bottleneck.

    The linear layers are sized for a 32x7x7 feature map, which is what the
    encoder produces for 1x28x28 inputs.

    Args:
        latent_dim: Size of the bottleneck embedding.
    """

    def __init__(self, latent_dim=32):
        super().__init__()
        conv_stack = [
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 14x14 for a 28x28 input
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 7x7
        ]
        self.encoder = nn.Sequential(*conv_stack)
        self.fc_enc = nn.Linear(32 * 7 * 7, latent_dim)

        self.fc_dec = nn.Linear(latent_dim, 32 * 7 * 7)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),  # 14x14
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1),  # 28x28
            nn.Sigmoid(),
        )

    def _latent(self, x):
        """Map images to bottleneck embeddings (gradients tracked)."""
        feats = self.encoder(x)
        return self.fc_enc(feats.view(feats.size(0), -1))

    def forward(self, x):
        """Return ``(reconstruction, embedding)`` for a batch of images."""
        z = self._latent(x)
        h = self.fc_dec(z)
        x_recon = self.decoder(h.view(h.size(0), 32, 7, 7))
        return x_recon, z

    def encode(self, x):
        """Return the embedding of ``x`` computed without gradient tracking."""
        with torch.no_grad():
            return self._latent(x)


# Train the autoencoder on the whole train_dataset (labels are not used).
# The raw uint8 images are scaled to [0, 1]; the dataset transform is bypassed.
full_data = train_dataset.data.unsqueeze(1).float() / 255.0
full_targets = train_dataset.targets
full_loader = make_loader(full_data, full_targets)

autoenc = Autoencoder(latent_dim=64).to(device)
optimizer_ae = optim.Adam(autoenc.parameters(), lr=0.001)
criterion_ae = nn.MSELoss()

for ep in range(epochs):
    autoenc.train()
    for d, t in full_loader:
        d = d.to(device)
        optimizer_ae.zero_grad()
        x_recon, z = autoenc(d)
        # Reconstruction loss against the input itself; the labels t are ignored.
        loss = criterion_ae(x_recon, d)
        loss.backward()
        optimizer_ae.step()

# Freeze the autoencoder (this loop covers all of its parameters, encoder and decoder).
for p in autoenc.parameters():
    p.requires_grad = False


# ---------------------------
# Minimal attention-style classifier.
# It holds one query vector per class and expects an embedding z coming from
# the encoder.  The predicted class is the one whose query has the highest
# cosine similarity with z.
# ---------------------------
class AttentionClassifier(nn.Module):
    """Score embeddings against one learnable query vector per class.

    Args:
        embedding_dim: Size of the incoming embeddings.
        num_classes: Number of query vectors (one per class).
    """

    def __init__(self, embedding_dim, num_classes):
        super().__init__()
        # One query vector per class, drawn from N(0, 1) and scaled by 0.1.
        self.queries = nn.Parameter(torch.randn(num_classes, embedding_dim) * 0.1)

    def forward(self, z):
        """Return the cosine similarities, shape [B, num_classes], for z of shape [B, embedding_dim]."""
        # The small constant keeps the division finite for zero vectors.
        z_len = z.norm(dim=1, keepdim=True) + 1e-9
        q_len = self.queries.norm(dim=1, keepdim=True) + 1e-9
        unit_z = z / z_len
        unit_q = self.queries / q_len
        return torch.matmul(unit_z, unit_q.t())


def train_classifier(model, loader, enc_model, optimizer, criterion):
    """Run one optimization pass of ``model`` on embeddings from ``enc_model``.

    Every batch is moved to the module-level ``device``, embedded with
    ``enc_model.encode`` and scored by ``model`` (output [B, num_classes]);
    one optimizer step is taken per batch.  Returns nothing.
    """
    model.train()
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        embeddings = enc_model.encode(images)  # feature extraction
        batch_loss = criterion(model(embeddings), labels)
        batch_loss.backward()
        optimizer.step()


def test_classifier(model, loader, enc_model):
    """Return the top-1 accuracy of ``model`` on embeddings from ``enc_model``.

    ``model`` is switched to eval mode and the whole pass runs without
    gradient tracking on the module-level ``device``.  An empty loader
    yields ``0.0``.
    """
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = model(enc_model.encode(images)).argmax(dim=1)
            hits += (predicted == labels).sum().item()
            seen += labels.size(0)
    if seen > 0:
        return hits / seen
    return 0.0


# Loss shared by both classifiers below.
criterion_cls = nn.CrossEntropyLoss()

# ---------------------------
# Train model A on classes {0,1}
# ---------------------------
train_data_A, train_targets_A, mapA = filter_dataset(train_dataset, [0, 1])
test_data_A, test_targets_A, _ = filter_dataset(test_dataset, [0, 1])
train_loader_A = make_loader(train_data_A, train_targets_A)
test_loader_A = make_loader(test_data_A, test_targets_A)

# embedding_dim matches the latent_dim=64 used when autoenc was created.
model_A = AttentionClassifier(embedding_dim=64, num_classes=2).to(device)
optimizer_A = optim.Adam(model_A.parameters(), lr=0.001)

for ep in range(epochs):
    train_classifier(model_A, train_loader_A, autoenc, optimizer_A, criterion_cls)
acc_A = test_classifier(model_A, test_loader_A, autoenc)
print("Accuracy A (0 vs 1):", acc_A)

# ---------------------------
# Train model B on classes {0,2}
# ---------------------------
# Same procedure for {0,2}; filter_dataset relabels the targets as 0->0, 2->1.
train_data_B, train_targets_B, mapB = filter_dataset(train_dataset, [0, 2])
test_data_B, test_targets_B, _ = filter_dataset(test_dataset, [0, 2])
train_loader_B = make_loader(train_data_B, train_targets_B)
test_loader_B = make_loader(test_data_B, test_targets_B)

model_B = AttentionClassifier(embedding_dim=64, num_classes=2).to(device)
optimizer_B = optim.Adam(model_B.parameters(), lr=0.001)

for ep in range(epochs):
    train_classifier(model_B, train_loader_B, autoenc, optimizer_B, criterion_cls)
acc_B = test_classifier(model_B, test_loader_B, autoenc)
print("Accuracy B (0 vs 2):", acc_B)

# ---------------------------
# At this point we have:
# Model A, which knows classes {0,1}
# Model B, which knows classes {0,2}

# The goal is to transfer the skill of recognizing "1" into model B.
# Model B currently holds queries for [0,2] and model A for [0,1].
# How can this be done?
# Take the query vector of class "1" from model_A:
# mapA: {0->0, 1->1}
# mapB: {0->0, 2->1}

# Model B has to be extended to 3 classes: {0,1,2}.
# The current model_B, however, is built for 2 classes.
# So a new model B' with 3 classes is created and the parameters are copied over.
# The new class (1) is taken from model_A.

# In short: create a new model B', copy the queries for 0 and 2 from model_B
# and add the query for 1 from model_A.

model_B_extended = AttentionClassifier(embedding_dim=64, num_classes=3).to(device)

with torch.no_grad():
    # model_B.queries[0] = class 0
    # model_B.queries[1] = class 2 (per mapB)
    # model_A.queries[1] = class 1 (per mapA)
    # Decide which query goes where:
    # the class order in the new model B' is [0,1,2]
    # class 0 - take model_B.queries[mapB[0]] = model_B.queries[0]
    model_B_extended.queries[0].copy_(model_B.queries[0])
    # class 1 - take class 1 from model_A: model_A.queries[1]
    model_B_extended.queries[1].copy_(model_A.queries[1])
    # class 2 - take class 2 from model_B -> mapB[2]=1, i.e. model_B.queries[1]
    model_B_extended.queries[2].copy_(model_B.queries[1])

# Check the accuracy of the new model B' on a test set with classes {0,1,2}.
# That requires a test set restricted to {0,1,2}:
test_data_F, test_targets_F, mapF = filter_dataset(test_dataset, [0, 1, 2])
test_loader_F = make_loader(test_data_F, test_targets_F)

# Placeholder only: no accuracy is measured into acc_B_before in this cell.
# B_extended can already attempt {0,1,2} without fine-tuning, but it knows
# class 1 only through the query taken from A.
acc_B_before = 0.0
acc_B_after = test_classifier(model_B_extended, test_loader_F, autoenc)

print("Accuracy B' (0,1,2) after skill transfer:", acc_B_after)

Accuracy A (0 vs 1): 0.9957446808510638
Accuracy B (0 vs 2): 0.9761431411530815
Accuracy B' (0,1,2) after skill transfer: 0.9536066094693358


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import copy

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 128
epochs = 5

# ---------------------------
# MNIST dataset
# ---------------------------
# Convert images to tensors, then normalize with mean 0.1307 and std 0.3081.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST("./data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST("./data", train=False, transform=transform)


def filter_dataset(dataset, classes):
    """Select the samples of the given classes and relabel them ``0..K-1``.

    Images are read straight from ``dataset.data`` (``dataset.transform`` is
    not applied) and scaled to ``[0, 1]`` by dividing by 255.

    Args:
        dataset: Object exposing tensor attributes ``data`` (N, H, W) and
            ``targets`` (N,).
        classes: Iterable of class labels to keep; duplicates are ignored.

    Returns:
        Tuple ``(data, new_targets, class_to_new)``: a float tensor of shape
        (M, 1, H, W), an int64 tensor of remapped labels of shape (M,), and
        the dict mapping each original label to its new index (in ascending
        label order).
    """
    # Deduplicate before enumerating: a repeated class would otherwise consume
    # an index and leave gaps in the new label range.
    class_to_new = {c: i for i, c in enumerate(sorted(set(classes)))}
    targets = torch.as_tensor(dataset.targets)
    mask = torch.zeros_like(targets, dtype=torch.bool)
    new_targets = torch.zeros_like(targets, dtype=torch.long)
    # One vectorized comparison per class replaces the per-sample Python loops.
    for c, i in class_to_new.items():
        hit = targets == c
        mask |= hit
        new_targets[hit] = i
    data = dataset.data[mask].unsqueeze(1).float() / 255.0
    return data, new_targets[mask], class_to_new


def make_loader(data, targets):
    """Return a shuffling ``DataLoader`` over ``(data, targets)`` pairs.

    Batches have the module-level ``batch_size``.
    """
    loader_kwargs = {"batch_size": batch_size, "shuffle": True}
    paired = torch.utils.data.TensorDataset(data, targets)
    return torch.utils.data.DataLoader(paired, **loader_kwargs)


# ---------------------------
# Autoencoder for unsupervised pre-training
# ---------------------------
class Autoencoder(nn.Module):
    """Convolutional autoencoder with a linear bottleneck.

    The linear layers are sized for a 32x7x7 feature map, which is what the
    encoder produces for 1x28x28 inputs.

    Args:
        latent_dim: Size of the bottleneck embedding.
    """

    def __init__(self, latent_dim=32):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 14x14 for a 28x28 input
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 7x7
        )
        self.fc_enc = nn.Linear(32 * 7 * 7, latent_dim)

        self.fc_dec = nn.Linear(latent_dim, 32 * 7 * 7)
        upsample = [
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),  # 14x14
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1),  # 28x28
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*upsample)

    def forward(self, x):
        """Return ``(reconstruction, embedding)`` for a batch of images."""
        z = self.encode(x)
        h = self.fc_dec(z)
        x_recon = self.decoder(h.view(h.size(0), 32, 7, 7))
        return x_recon, z

    def encode(self, x):
        """Return the embedding of ``x``.

        No ``torch.no_grad()`` here, so gradients flow through the encoder.
        """
        feats = self.encoder(x)
        flat = feats.view(feats.size(0), -1)
        return self.fc_enc(flat)


# Train the autoencoder on the whole train_dataset (labels are not used).
# The raw uint8 images are scaled to [0, 1]; the dataset transform is bypassed.
full_data = train_dataset.data.unsqueeze(1).float() / 255.0
full_targets = train_dataset.targets
full_loader = make_loader(full_data, full_targets)

autoenc = Autoencoder(latent_dim=64).to(device)
optimizer_ae = optim.Adam(autoenc.parameters(), lr=0.001)
criterion_ae = nn.MSELoss()

for ep in range(epochs):
    autoenc.train()
    for d, t in full_loader:
        d = d.to(device)
        optimizer_ae.zero_grad()
        x_recon, z = autoenc(d)
        # Reconstruction loss against the input itself; the labels t are ignored.
        loss = criterion_ae(x_recon, d)
        loss.backward()
        optimizer_ae.step()

# Freeze the autoencoder so it can serve as a fixed backbone
# (this loop covers all of its parameters, encoder and decoder).
for p in autoenc.parameters():
    p.requires_grad = False


# ---------------------------
# Minimal attention-style classifier
# ---------------------------
class AttentionClassifier(nn.Module):
    """Score embeddings against one learnable query vector per class.

    Args:
        embedding_dim: Size of the incoming embeddings.
        num_classes: Number of query vectors (one per class).
    """

    def __init__(self, embedding_dim, num_classes):
        super().__init__()
        # One query vector per class, drawn from N(0, 1) and scaled by 0.1.
        initial = torch.randn(num_classes, embedding_dim) * 0.1
        self.queries = nn.Parameter(initial)

    def forward(self, z):
        """Return the cosine similarities, shape [B, num_classes], for z of shape [B, embedding_dim]."""
        # The small constant keeps the division finite for zero vectors.
        unit_z = z / (z.norm(dim=1, keepdim=True) + 1e-9)
        unit_q = self.queries / (self.queries.norm(dim=1, keepdim=True) + 1e-9)
        return torch.matmul(unit_z, unit_q.t())


def train_classifier(model, loader, enc_model, optimizer, criterion):
    """Run one optimization pass of ``model`` on embeddings from ``enc_model``.

    ``model`` is put in training mode and ``enc_model`` in eval mode (the
    encoder is frozen; eval is set for consistency).  Every batch is moved to
    the module-level ``device`` and one optimizer step is taken per batch.
    Returns nothing.
    """
    model.train()
    enc_model.eval()
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        scores = model(enc_model.encode(images))
        criterion(scores, labels).backward()
        optimizer.step()


def test_classifier(model, loader, enc_model):
    """Return the top-1 accuracy of ``model`` on embeddings from ``enc_model``.

    Both modules are switched to eval mode and the pass runs without gradient
    tracking on the module-level ``device``.  An empty loader yields ``0.0``.
    """
    model.eval()
    enc_model.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)
            scores = model(enc_model.encode(images))
            hits += (scores.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)
    if seen > 0:
        return hits / seen
    return 0.0


# Loss shared by the classifiers and the backbone fine-tuning below.
criterion_cls = nn.CrossEntropyLoss()

# ---------------------------
# Train model A on classes {0,1}
# ---------------------------
train_data_A, train_targets_A, mapA = filter_dataset(train_dataset, [0, 1])
test_data_A, test_targets_A, _ = filter_dataset(test_dataset, [0, 1])
train_loader_A = make_loader(train_data_A, train_targets_A)
test_loader_A = make_loader(test_data_A, test_targets_A)

# embedding_dim matches the latent_dim=64 used when autoenc was created.
model_A = AttentionClassifier(embedding_dim=64, num_classes=2).to(device)
optimizer_A = optim.Adam(model_A.parameters(), lr=0.001)

for ep in range(epochs):
    train_classifier(model_A, train_loader_A, autoenc, optimizer_A, criterion_cls)
acc_A = test_classifier(model_A, test_loader_A, autoenc)
print("Accuracy A (0 vs 1):", acc_A)

# ---------------------------
# Train model B on classes {0,2}
# ---------------------------
# filter_dataset relabels the targets as 0->0, 2->1.
train_data_B, train_targets_B, mapB = filter_dataset(train_dataset, [0, 2])
test_data_B, test_targets_B, _ = filter_dataset(test_dataset, [0, 2])
train_loader_B = make_loader(train_data_B, train_targets_B)
test_loader_B = make_loader(test_data_B, test_targets_B)

model_B = AttentionClassifier(embedding_dim=64, num_classes=2).to(device)
optimizer_B = optim.Adam(model_B.parameters(), lr=0.001)

for ep in range(epochs):
    train_classifier(model_B, train_loader_B, autoenc, optimizer_B, criterion_cls)
acc_B = test_classifier(model_B, test_loader_B, autoenc)
print("Accuracy B (0 vs 2):", acc_B)

# ---------------------------
# Skill transfer: assemble B' for {0,1,2}
# ---------------------------
# In model_A: classes {0->0, 1->1}
# In model_B: classes {0->0, 2->1}
# In the new model: {0,1,2} -> indices 0->0, 1->1, 2->2
model_B_extended = AttentionClassifier(embedding_dim=64, num_classes=3).to(device)

with torch.no_grad():
    # Class 0 of B': taken from model_B (class 0 has index 0 in mapB)
    model_B_extended.queries[0].copy_(model_B.queries[0])
    # Class 1 of B': taken from model_A (class 1 has index 1 in mapA)
    model_B_extended.queries[1].copy_(model_A.queries[1])
    # Class 2 of B': taken from model_B (class 2 has index 1 in mapB)
    model_B_extended.queries[2].copy_(model_B.queries[1])

# Evaluate the stitched model on the three-class test set before any fine-tuning.
test_data_F, test_targets_F, mapF = filter_dataset(test_dataset, [0, 1, 2])
test_loader_F = make_loader(test_data_F, test_targets_F)
acc_B_before = test_classifier(model_B_extended, test_loader_F, autoenc)
print("Accuracy B' (0,1,2) after skill transfer (before finetuning encoder):", acc_B_before)

# ---------------------------
# Fine-tune the encoder while the attention queries stay fixed.
# The queries of model_B_extended are frozen and only the encoding half of
# the autoencoder (encoder and fc_enc) is unfrozen; the decoding half
# (decoder and fc_dec) stays frozen.
# ---------------------------
# Freeze the queries
model_B_extended.queries.requires_grad = False

# A parameter is trainable exactly when it does not belong to the decoding half.
for name, p in autoenc.named_parameters():
    p.requires_grad = not ("decoder" in name or "fc_dec" in name)

# Optimize the trainable (encoder) parameters only
optimizer_enc = optim.Adam(list(filter(lambda q: q.requires_grad, autoenc.parameters())), lr=0.0005)


def train_backbone_with_fixed_attention(model, enc_model, loader, optimizer, criterion):
    """Run one optimization pass in which only the encoder is meant to learn.

    ``model`` (the attention head) is put in eval mode and ``enc_model`` in
    training mode.  The embeddings come from ``enc_model.encode`` and must
    carry a ``grad_fn`` so the loss can reach the encoder weights.  Which
    parameters actually change is decided by what ``optimizer`` holds.
    Returns nothing.
    """
    model.eval()  # the attention head stays fixed
    enc_model.train()
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        embeddings = enc_model.encode(images)
        batch_loss = criterion(model(embeddings), labels)
        batch_loss.backward()
        optimizer.step()


# Accuracy before fine-tuning the backbone
acc_before_fine = test_classifier(model_B_extended, test_loader_F, autoenc)
print("Accuracy before fine-tuning the backbone:", acc_before_fine)

# train_loader_F is not created anywhere earlier in this cell, so build the
# {0,1,2} training loader here.  For these classes filter_dataset keeps the
# labels as 0->0, 1->1, 2->2, which matches the query order of
# model_B_extended.
train_data_F, train_targets_F, _ = filter_dataset(train_dataset, [0, 1, 2])
train_loader_F = make_loader(train_data_F, train_targets_F)

# Run a few epochs of encoder training
for ep in range(3):
    train_backbone_with_fixed_attention(model_B_extended, autoenc, train_loader_F, optimizer_enc, criterion_cls)

acc_after = test_classifier(model_B_extended, test_loader_F, autoenc)
print("Accuracy after fine-tuning the backbone with fixed attention:", acc_after)

Accuracy A (0 vs 1): 0.9966903073286052
Accuracy B (0 vs 2): 0.9756461232604374
Accuracy B' (0,1,2) after skill transfer (before finetuning encoder): 0.9482046393390531
Accuracy before fine-tuning the backbone: 0.9482046393390531
Accuracy after fine-tuning the backbone with fixed attention: 0.9914204003813155


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import copy
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 128
epochs = 5

# ---------------------------
# MNIST dataset
# ---------------------------
# Convert images to tensors, then normalize with mean 0.1307 and std 0.3081.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST("./data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST("./data", train=False, transform=transform)


def filter_dataset(dataset, classes):
    """Select the samples of the given classes and relabel them ``0..K-1``.

    Images are read straight from ``dataset.data`` (``dataset.transform`` is
    not applied) and scaled to ``[0, 1]`` by dividing by 255.

    Args:
        dataset: Object exposing tensor attributes ``data`` (N, H, W) and
            ``targets`` (N,).
        classes: Iterable of class labels to keep; duplicates are ignored.

    Returns:
        Tuple ``(data, new_targets, class_to_new)``: a float tensor of shape
        (M, 1, H, W), an int64 tensor of remapped labels of shape (M,), and
        the dict mapping each original label to its new index (in ascending
        label order).
    """
    # Deduplicate before enumerating: a repeated class would otherwise consume
    # an index and leave gaps in the new label range.
    class_to_new = {c: i for i, c in enumerate(sorted(set(classes)))}
    targets = torch.as_tensor(dataset.targets)
    mask = torch.zeros_like(targets, dtype=torch.bool)
    new_targets = torch.zeros_like(targets, dtype=torch.long)
    # One vectorized comparison per class replaces the per-sample Python loops.
    for c, i in class_to_new.items():
        hit = targets == c
        mask |= hit
        new_targets[hit] = i
    data = dataset.data[mask].unsqueeze(1).float() / 255.0
    return data, new_targets[mask], class_to_new


def make_loader(data, targets):
    """Return a shuffling ``DataLoader`` over ``(data, targets)`` pairs.

    Batches have the module-level ``batch_size``.
    """
    loader_cls = torch.utils.data.DataLoader
    samples = torch.utils.data.TensorDataset(data, targets)
    return loader_cls(samples, batch_size=batch_size, shuffle=True)


# ---------------------------
# Autoencoder for unsupervised pre-training
# ---------------------------
class Autoencoder(nn.Module):
    """Convolutional autoencoder with a linear bottleneck.

    The linear layers are sized for a 32x7x7 feature map, which is what the
    encoder produces for 1x28x28 inputs.

    Args:
        latent_dim: Size of the bottleneck embedding.
    """

    def __init__(self, latent_dim=32):
        super().__init__()
        flat_features = 32 * 7 * 7
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 14x14 for a 28x28 input
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 7x7
        )
        self.fc_enc = nn.Linear(flat_features, latent_dim)

        self.fc_dec = nn.Linear(latent_dim, flat_features)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),  # 14x14
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=4, stride=2, padding=1),  # 28x28
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``(reconstruction, embedding)`` for a batch of images."""
        feats = self.encoder(x)
        z = self.fc_enc(feats.view(feats.size(0), -1))

        expanded = self.fc_dec(z)
        x_recon = self.decoder(expanded.view(expanded.size(0), 32, 7, 7))
        return x_recon, z

    def encode(self, x):
        """Return the embedding of ``x`` (gradients are tracked)."""
        feats = self.encoder(x)
        return self.fc_enc(feats.view(feats.size(0), -1))


# The full MNIST training split feeds the autoencoder.
# The raw uint8 images are scaled to [0, 1]; the dataset transform is bypassed.
full_data = train_dataset.data.unsqueeze(1).float() / 255.0
full_targets = train_dataset.targets
full_loader = make_loader(full_data, full_targets)

autoenc = Autoencoder(latent_dim=64).to(device)
optimizer_ae = optim.Adam(autoenc.parameters(), lr=0.001)
criterion_ae = nn.MSELoss()

# Per-epoch reconstruction loss of the autoencoder
ae_loss_log = []

for ep in range(epochs):
    autoenc.train()
    running_loss = 0.0
    for d, t in full_loader:
        d = d.to(device)
        optimizer_ae.zero_grad()
        x_recon, z = autoenc(d)
        # Reconstruction loss against the input itself; the labels t are ignored.
        loss = criterion_ae(x_recon, d)
        loss.backward()
        optimizer_ae.step()
        # Weight the batch loss by the batch size so the epoch figure is a per-sample average.
        running_loss += loss.item() * d.size(0)
    ae_loss_log.append(running_loss / len(full_loader.dataset))

# Freeze the autoencoder (this loop covers all of its parameters, encoder and decoder).
for p in autoenc.parameters():
    p.requires_grad = False


# ---------------------------
# Attention Classifier
# ---------------------------
class AttentionClassifier(nn.Module):
    """Score embeddings against one learnable query vector per class.

    Args:
        embedding_dim: Size of the incoming embeddings.
        num_classes: Number of query vectors (one per class).
    """

    def __init__(self, embedding_dim, num_classes):
        super().__init__()
        # One query vector per class, drawn from N(0, 1) and scaled by 0.1.
        self.queries = nn.Parameter(torch.randn(num_classes, embedding_dim) * 0.1)

    def forward(self, z):
        """Return the cosine similarity between each row of ``z`` and each query."""
        # The small constant keeps the division finite for zero vectors.
        z_scale = z.norm(dim=1, keepdim=True) + 1e-9
        q_scale = self.queries.norm(dim=1, keepdim=True) + 1e-9
        normalized_queries = self.queries / q_scale
        return torch.matmul(z / z_scale, normalized_queries.t())


def train_classifier(model, loader, enc_model, optimizer, criterion):
    """Run one optimization pass and return the running training accuracy.

    ``model`` is put in training mode and ``enc_model`` in eval mode.  Each
    batch is moved to the module-level ``device``, embedded with
    ``enc_model.encode`` and scored by ``model``.  The accuracy is computed
    from the scores produced before each optimizer step.

    Returns:
        Fraction of correctly classified samples seen during the pass, or
        ``0.0`` for an empty loader.
    """
    model.train()
    enc_model.eval()
    n_correct = 0.0
    n_seen = 0
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        scores = model(enc_model.encode(images))
        criterion(scores, labels).backward()
        optimizer.step()
        n_correct += (scores.argmax(dim=1) == labels).sum().item()
        n_seen += labels.size(0)
    if n_seen > 0:
        return n_correct / n_seen
    return 0.0


def test_classifier(model, loader, enc_model):
    """Return the fraction of samples in ``loader`` that ``model`` labels correctly.

    Both networks are switched to eval mode and no gradients are recorded.
    Batches are moved to the global ``device`` and encoded with
    ``enc_model.encode`` before being scored. An empty loader yields ``0.0``.
    """
    model.eval()
    enc_model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for data, target in loader:
            data = data.to(device)
            target = target.to(device)
            scores = model(enc_model.encode(data))
            hits += (scores.argmax(dim=1) == target).sum().item()
            seen += target.size(0)
    if seen > 0:
        return hits / seen
    return 0.0


# Loss shared by every classifier trained below.
criterion_cls = nn.CrossEntropyLoss()

# ---------------------------
# Train model A: {0 vs 1}
# ---------------------------
# NOTE(review): three values are unpacked here; confirm that the filter_dataset
# in scope for this cell returns (data, targets, mapping).
train_data_A, train_targets_A, mapA = filter_dataset(train_dataset, [0, 1])
test_data_A, test_targets_A, _ = filter_dataset(test_dataset, [0, 1])
train_loader_A = make_loader(train_data_A, train_targets_A)
test_loader_A = make_loader(test_data_A, test_targets_A)

# embedding_dim=64 matches the latent_dim the autoencoder was built with.
model_A = AttentionClassifier(embedding_dim=64, num_classes=2).to(device)
optimizer_A = optim.Adam(model_A.parameters(), lr=0.001)
# Training accuracy of model A, one entry per epoch.
acc_A_log = []

for ep in range(epochs):
    acc_ep = train_classifier(model_A, train_loader_A, autoenc, optimizer_A, criterion_cls)
    acc_A_log.append(acc_ep)
acc_A = test_classifier(model_A, test_loader_A, autoenc)

print("Accuracy A (0 vs 1):", acc_A)

# ---------------------------
# Train model B: {0 vs 2}
# ---------------------------
train_data_B, train_targets_B, mapB = filter_dataset(train_dataset, [0, 2])
test_data_B, test_targets_B, _ = filter_dataset(test_dataset, [0, 2])
train_loader_B = make_loader(train_data_B, train_targets_B)
test_loader_B = make_loader(test_data_B, test_targets_B)

model_B = AttentionClassifier(embedding_dim=64, num_classes=2).to(device)
optimizer_B = optim.Adam(model_B.parameters(), lr=0.001)
# Training accuracy of model B, one entry per epoch.
acc_B_log = []

for ep in range(epochs):
    acc_ep = train_classifier(model_B, train_loader_B, autoenc, optimizer_B, criterion_cls)
    acc_B_log.append(acc_ep)

acc_B = test_classifier(model_B, test_loader_B, autoenc)
print("Accuracy B (0 vs 2):", acc_B)

# ---------------------------
# Skill transfer: assemble B' for {0,1,2}
# ---------------------------
model_B_extended = AttentionClassifier(embedding_dim=64, num_classes=3).to(device)

# Overwrite the random queries of B' with rows taken from the two trained
# binary models; no_grad keeps the in-place copies out of autograd.
with torch.no_grad():
    model_B_extended.queries[0].copy_(model_B.queries[0])  # class 0 from B
    model_B_extended.queries[1].copy_(model_A.queries[1])  # class 1 from A
    model_B_extended.queries[2].copy_(model_B.queries[1])  # class 2 from B

test_data_F, test_targets_F, mapF = filter_dataset(test_dataset, [0, 1, 2])
test_loader_F = make_loader(test_data_F, test_targets_F)
acc_B_before = test_classifier(model_B_extended, test_loader_F, autoenc)
# The same value is printed twice, under two different labels.
print("Accuracy B' (0,1,2) after skill transfer (before finetuning encoder):", acc_B_before)
print("Accuracy before fine-tuning the backbone:", acc_B_before)

# ---------------------------
# Fine-tune the encoder while the attention queries stay fixed
# ---------------------------
model_B_extended.queries.requires_grad = False
# Re-enable gradients for the encoder side only: parameters whose name
# contains "decoder" or "fc_dec" stay frozen, all others become trainable.
for name, p in autoenc.named_parameters():
    if "decoder" in name or "fc_dec" in name:
        p.requires_grad = False
    else:
        p.requires_grad = True

# The optimizer receives only the parameters that were just made trainable.
optimizer_enc = optim.Adam([p for p in autoenc.parameters() if p.requires_grad], lr=0.0005)


def train_backbone_with_fixed_attention(model, enc_model, loader, optimizer, criterion):
    """Run one epoch that trains ``enc_model`` through a classifier head in eval mode.

    ``model`` is switched to eval mode and ``enc_model`` to train mode; which
    parameters actually move is decided by what ``optimizer`` was built with.
    Batches are moved to the global ``device``.

    Returns:
        Accuracy over the epoch, computed from the scores obtained before each
        optimizer step; ``0.0`` when the loader yields no samples.
    """
    model.eval()
    enc_model.train()
    hits = 0
    seen = 0
    for data, target in loader:
        data = data.to(device)
        target = target.to(device)
        optimizer.zero_grad()
        scores = model(enc_model.encode(data))
        criterion(scores, target).backward()
        optimizer.step()
        hits += (scores.argmax(dim=1) == target).sum().item()
        seen += target.size(0)
    if seen > 0:
        return hits / seen
    return 0.0


# Three-class {0,1,2} training loader for the fine-tuning stage. It is built
# here, mirroring how test_loader_F is built, so that this step does not
# depend on a `train_loader_F` left over from some other cell or run.
# NOTE(review): if an earlier definition used different data, reconcile the two.
train_data_F, train_targets_F, _ = filter_dataset(train_dataset, [0, 1, 2])
train_loader_F = make_loader(train_data_F, train_targets_F)

# Training accuracy during encoder fine-tuning, one entry per epoch.
acc_finetune_log = []
for ep in range(3):
    acc_ep = train_backbone_with_fixed_attention(model_B_extended, autoenc, train_loader_F, optimizer_enc, criterion_cls)
    acc_finetune_log.append(acc_ep)

# Re-evaluate B' on the same three-class test loader used before fine-tuning.
acc_after = test_classifier(model_B_extended, test_loader_F, autoenc)
print("Accuracy after fine-tuning the backbone with fixed attention:", acc_after)

# ---------------------------
# Plot the results with Plotly
# ---------------------------

# 2x2 grid: autoencoder loss, the two binary classifiers, encoder fine-tuning.
fig = make_subplots(rows=2, cols=2, subplot_titles=("AE Loss", "Model A Accuracy", "Model B Accuracy", "Fine-tuning Accuracy (encoder)"))

# AE Loss
fig.add_trace(go.Scatter(y=ae_loss_log, mode="lines+markers", name="AE Loss"), row=1, col=1)

# Model A Accuracy
fig.add_trace(go.Scatter(y=acc_A_log, mode="lines+markers", name="A Accuracy (train)"), row=1, col=2)

# Model B Accuracy
fig.add_trace(go.Scatter(y=acc_B_log, mode="lines+markers", name="B Accuracy (train)"), row=2, col=1)

# Fine-tuning accuracy
fig.add_trace(go.Scatter(y=acc_finetune_log, mode="lines+markers", name="Encoder finetune Acc (train)"), row=2, col=2)

fig.update_layout(height=600, width=800, title_text="Training Convergence Analysis")
fig.show()

Accuracy A (0 vs 1): 0.9962174940898345
Accuracy B (0 vs 2): 0.9801192842942346
Accuracy B' (0,1,2) after skill transfer (before finetuning encoder): 0.9590085795996187
Accuracy before fine-tuning the backbone: 0.9590085795996187
Accuracy after fine-tuning the backbone with fixed attention: 0.9895138226882746


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import copy
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Global settings read by the helpers below (device, batch_size) and by the
# training loops (epochs).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 128
epochs = 3  # can be increased for better accuracy

# ---------------------------
# MNIST dataset
# ---------------------------
# NOTE(review): the code below reads `.data` / `.targets` directly and rescales
# by 1/255 itself, so this transform does not appear to be applied to the
# tensors that are actually trained on — confirm whether that is intended.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST("./data", train=True, download=True, transform=transform)
test_dataset = datasets.MNIST("./data", train=False, transform=transform)


def filter_dataset(dataset, classes):
    """Select the samples of ``dataset`` whose label is one of ``classes``.

    Labels are re-indexed to ``0..K-1`` following the ascending order of the
    distinct requested classes, so a classifier with ``K`` outputs can be
    trained on the result directly.

    Args:
        dataset: object exposing ``data`` (a tensor indexable by a boolean
            mask along its first axis) and ``targets`` (one integer label per
            sample).
        classes: iterable of integer class ids to keep. Duplicates are
            ignored, so the new indices are always contiguous.

    Returns:
        Tuple ``(data, new_targets, class_to_new)`` where ``data`` is the
        selected samples as float with an extra axis inserted at position 1
        and divided by 255, ``new_targets`` is an int64 tensor of re-indexed
        labels (int64 even when nothing matches), and ``class_to_new`` maps
        each kept original class id to its new index.
    """
    # De-duplicate first: with repeated ids a plain enumerate(sorted(classes))
    # would skip index values.
    kept_classes = sorted(set(classes))
    class_to_new = {c: i for i, c in enumerate(kept_classes)}

    labels = torch.as_tensor(dataset.targets)
    wanted = torch.tensor(kept_classes, dtype=labels.dtype, device=labels.device)

    # Vectorised membership test instead of a Python loop over every label.
    mask = torch.isin(labels, wanted)
    data = dataset.data[mask].unsqueeze(1).float() / 255.0

    # `wanted` is sorted and every kept label occurs in it, so the bucket
    # index of a label is exactly its position in `wanted`, i.e. its new id.
    new_targets = torch.bucketize(labels[mask], wanted)
    return data, new_targets, class_to_new


def make_loader(data, targets):
    """Wrap ``data`` and ``targets`` in a shuffling DataLoader.

    The batch size comes from the module-level ``batch_size``.
    """
    return torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(data, targets),
        shuffle=True,
        batch_size=batch_size,
    )


# ---------------------------
# Autoencoder for unsupervised pre-training (backbone)
# ---------------------------
class Autoencoder(nn.Module):
    """Small convolutional autoencoder with a linear bottleneck.

    The linear layers are sized for a 32x7x7 feature map, which is what the
    two shape-preserving convolutions plus two 2x2 poolings produce for
    single-channel 28x28 inputs.
    """

    def __init__(self, latent_dim=32):
        super().__init__()
        conv_stack = [
            nn.Conv2d(1, 16, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        self.encoder = nn.Sequential(*conv_stack)
        self.fc_enc = nn.Linear(32 * 7 * 7, latent_dim)
        self.fc_dec = nn.Linear(latent_dim, 32 * 7 * 7)
        deconv_stack = [
            nn.ConvTranspose2d(32, 16, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 4, 2, 1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*deconv_stack)

    def forward(self, x):
        """Return ``(x_recon, z)``: the reconstruction and the latent code."""
        z = self.encode(x)
        # Expand the code back to the 32x7x7 map the transposed convs expect.
        feature_map = self.fc_dec(z).view(z.size(0), 32, 7, 7)
        return self.decoder(feature_map), z

    def encode(self, x):
        """Return only the latent code of ``x`` (no decoding)."""
        features = self.encoder(x)
        flat = features.view(features.size(0), -1)
        return self.fc_enc(flat)


# Whole MNIST training set for the autoencoder: add an axis at position 1 and
# divide the stored values by 255 (presumably raw 8-bit pixels — confirm).
full_data = train_dataset.data.unsqueeze(1).float() / 255.0
# Labels are carried by the loader but never read by the reconstruction loss.
full_targets = train_dataset.targets
full_loader = make_loader(full_data, full_targets)

autoenc = Autoencoder(latent_dim=64).to(device)
optimizer_ae = optim.Adam(autoenc.parameters(), lr=0.001)
criterion_ae = nn.MSELoss()

# Per-epoch log of the autoencoder reconstruction loss
ae_loss_log = []
for ep in range(epochs):
    autoenc.train()
    running_loss = 0.0
    for d, t in full_loader:
        d = d.to(device)
        optimizer_ae.zero_grad()
        x_recon, z = autoenc(d)
        # Reconstruction objective: the input itself is the target.
        loss = criterion_ae(x_recon, d)
        loss.backward()
        optimizer_ae.step()
        # Weight the per-batch mean by the batch size so the logged value is a
        # per-sample mean over the epoch.
        running_loss += loss.item() * d.size(0)
    ae_loss_log.append(running_loss / len(full_loader.dataset))

# Freeze the autoencoder: gradients are disabled for all of its parameters
# (encoder and decoder), so the classifiers below train on fixed embeddings.
for p in autoenc.parameters():
    p.requires_grad = False


# ---------------------------
# Attention Classifier
# ---------------------------
class AttentionClassifier(nn.Module):
    """Scores embeddings by cosine similarity to one learnable query per class."""

    def __init__(self, embedding_dim, num_classes):
        super().__init__()
        # One query row per class, drawn from a normal scaled down to 0.1.
        start = 0.1 * torch.randn(num_classes, embedding_dim)
        self.queries = nn.Parameter(start)

    def forward(self, z):
        """Return cosine similarities, one row per input and one column per class."""
        # The 1e-9 term keeps the division finite for an all-zero row.
        z_len = z.norm(dim=1, keepdim=True) + 1e-9
        q_len = self.queries.norm(dim=1, keepdim=True) + 1e-9
        unit_z = z / z_len
        unit_q = self.queries / q_len
        return torch.matmul(unit_z, unit_q.t())


def train_classifier(model, loader, enc_model, optimizer, criterion):
    """Train ``model`` for one pass over ``loader`` on top of ``enc_model`` embeddings.

    ``model`` runs in train mode, ``enc_model`` in eval mode. Every batch is
    moved to the global ``device`` and encoded with ``enc_model.encode``.

    Returns:
        Fraction of samples classified correctly during the pass (using the
        scores computed before each update), or ``0.0`` for an empty loader.
    """
    model.train()
    enc_model.eval()
    n_correct, n_seen = 0, 0
    for batch, labels in loader:
        batch = batch.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        embeddings = enc_model.encode(batch)
        scores = model(embeddings)
        batch_loss = criterion(scores, labels)
        batch_loss.backward()
        optimizer.step()
        n_correct += (scores.argmax(dim=1) == labels).sum().item()
        n_seen += labels.size(0)
    if n_seen == 0:
        return 0.0
    return n_correct / n_seen


def test_classifier(model, loader, enc_model):
    """Evaluate ``model`` on ``loader`` and collect its predictions.

    Both networks are put in eval mode and gradients are disabled. Batches are
    moved to the global ``device`` and encoded with ``enc_model.encode``.

    Returns:
        ``(accuracy, pred_all, target_all)`` where the last two are CPU
        tensors concatenated over all batches in the order the loader yielded
        them; ``(0.0, None, None)`` when no sample was seen.
    """
    model.eval()
    enc_model.eval()
    n_correct, n_seen = 0, 0
    batch_preds, batch_targets = [], []
    with torch.no_grad():
        for batch, labels in loader:
            batch = batch.to(device)
            labels = labels.to(device)
            guesses = model(enc_model.encode(batch)).argmax(dim=1)
            n_correct += (guesses == labels).sum().item()
            n_seen += labels.size(0)
            batch_preds.append(guesses.cpu())
            batch_targets.append(labels.cpu())
    if n_seen == 0:
        return 0.0, None, None
    return (n_correct / n_seen), torch.cat(batch_preds), torch.cat(batch_targets)


# Loss shared by every binary classifier trained below.
criterion_cls = nn.CrossEntropyLoss()

# A "skill" (query vector) for each class i=1..9 is obtained from a {0 vs i} model.
# The first model, {0 vs 1}, supplies the queries for class 0 and class 1.
# The models {0 vs 2}, {0 vs 3}, ... then supply the queries for the other classes.

queries_dict = {}  # dict: class_digit -> query_vector
acc_logs = {}  # (0, i) -> test accuracy of the {0 vs i} model
acc_train_logs = {}  # (0, i) -> list of per-epoch training accuracies

# First, the model for {0 vs 1}
train_data_01, train_targets_01, _ = filter_dataset(train_dataset, [0, 1])
test_data_01, test_targets_01, _ = filter_dataset(test_dataset, [0, 1])
train_loader_01 = make_loader(train_data_01, train_targets_01)
test_loader_01 = make_loader(test_data_01, test_targets_01)

# embedding_dim=64 matches the latent_dim the autoencoder was built with.
model_01 = AttentionClassifier(embedding_dim=64, num_classes=2).to(device)
optimizer_01 = optim.Adam(model_01.parameters(), lr=0.001)

acc_01_log = []
for ep in range(epochs):
    acc_ep = train_classifier(model_01, train_loader_01, autoenc, optimizer_01, criterion_cls)
    acc_01_log.append(acc_ep)
acc_01, _, _ = test_classifier(model_01, test_loader_01, autoenc)

# Keep the queries for 0 and 1 (cloned so they are detached from the model)
with torch.no_grad():
    queries_dict[0] = model_01.queries[0].clone()  # class 0
    queries_dict[1] = model_01.queries[1].clone()  # class 1

acc_logs[(0, 1)] = acc_01
acc_train_logs[(0, 1)] = acc_01_log

# Now all remaining classes: {0 vs i}, i=2..9
for i in range(2, 10):
    train_data_i, train_targets_i, _ = filter_dataset(train_dataset, [0, i])
    test_data_i, test_targets_i, _ = filter_dataset(test_dataset, [0, i])
    train_loader_i = make_loader(train_data_i, train_targets_i)
    test_loader_i = make_loader(test_data_i, test_targets_i)

    # A fresh binary model and optimizer for every pair.
    model_i = AttentionClassifier(embedding_dim=64, num_classes=2).to(device)
    optimizer_i = optim.Adam(model_i.parameters(), lr=0.001)

    acc_i_log = []
    for ep in range(epochs):
        acc_ep = train_classifier(model_i, train_loader_i, autoenc, optimizer_i, criterion_cls)
        acc_i_log.append(acc_ep)
    acc_i, _, _ = test_classifier(model_i, test_loader_i, autoenc)

    # Keep the query for class i (class 0 is already stored); the class-0
    # query learned by this pair's model is discarded.
    with torch.no_grad():
        # model_i.queries[0] is class 0, model_i.queries[1] is class i
        queries_dict[i] = model_i.queries[1].clone()

    acc_logs[(0, i)] = acc_i
    acc_train_logs[(0, i)] = acc_i_log


# queries_dict now holds a query for every class 0..9.
# Assemble the final 10-class model from them:
class FullAttentionClassifier(nn.Module):
    """Ten-class cosine-similarity classifier built from ready-made queries.

    ``queries_dict`` maps each digit 0..9 to its query vector. ``embedding_dim``
    is accepted for signature parity but is not read.
    """

    def __init__(self, embedding_dim, queries_dict):
        super().__init__()
        # Row d of the parameter is the query for digit d. The rows are kept
        # as a Parameter for consistency with the other classifiers, although
        # they could just as well be frozen.
        rows = [queries_dict[digit] for digit in range(10)]
        self.queries = nn.Parameter(torch.stack(rows, dim=0))

    def forward(self, z):
        """Return cosine similarities, one row per input and one column per digit."""
        # 1e-9 protects against division by a zero norm.
        unit_z = z / (z.norm(dim=1, keepdim=True) + 1e-9)
        unit_q = self.queries / (self.queries.norm(dim=1, keepdim=True) + 1e-9)
        return unit_z @ unit_q.t()


model_full = FullAttentionClassifier(embedding_dim=64, queries_dict=queries_dict).to(device)

# Check accuracy on the full MNIST test set (all classes 0..9)
test_loader_full = make_loader(test_dataset.data.unsqueeze(1).float() / 255.0, test_dataset.targets)
# pred_all and target_all are collected batch by batch in the same loop, so
# they stay aligned with each other even though the loader shuffles.
acc_full, pred_all, target_all = test_classifier(model_full, test_loader_full, autoenc)
error_rate = 1 - acc_full

print("Full model accuracy on all digits (0-9):", acc_full)
print("Error rate:", error_rate)

# Compute the accuracy of each class separately:
if pred_all is not None and target_all is not None:
    class_correct = [0] * 10
    class_total = [0] * 10
    for p, t in zip(pred_all, target_all):
        class_total[t.item()] += 1
        if p.item() == t.item():
            class_correct[t.item()] += 1

    for d in range(10):
        if class_total[d] > 0:
            print(f"Class {d} accuracy: {class_correct[d]/class_total[d]:.4f}")
        else:
            print(f"Class {d} no samples in test?")

# Print the architecture and parameter count of the final network
# (only the classifier head is counted; the autoencoder is a separate module)
print("Model architecture:")
print(model_full)
num_params = sum(p.numel() for p in model_full.parameters())
print("Total number of parameters in the final model:", num_params)

# Plot autoencoder convergence and the training accuracy of the {0 vs i} pairs
fig = make_subplots(rows=2, cols=2, subplot_titles=("AE Loss", "0 vs 1 Accuracy", "0 vs i Accuracy examples", "Full model accuracy not tracked by epoch but final"))

# AE Loss
fig.add_trace(go.Scatter(y=ae_loss_log, mode="lines+markers", name="AE Loss"), row=1, col=1)

# Accuracy 0 vs 1
fig.add_trace(go.Scatter(y=acc_train_logs[(0, 1)], mode="lines+markers", name="0 vs 1 train acc"), row=1, col=2)

# Show {0 vs 2} and {0 vs 9} as examples
if (0, 2) in acc_train_logs:
    fig.add_trace(go.Scatter(y=acc_train_logs[(0, 2)], mode="lines+markers", name="0 vs 2 train acc"), row=2, col=1)
if (0, 9) in acc_train_logs:
    fig.add_trace(go.Scatter(y=acc_train_logs[(0, 9)], mode="lines+markers", name="0 vs 9 train acc"), row=2, col=1)

# The full model has no epochs (it was assembled from pieces), so only the final value is shown:
fig.add_trace(go.Scatter(x=[0], y=[acc_full], mode="markers", name="Full Model Final Acc"), row=2, col=2)

fig.update_layout(height=600, width=900, title_text="Training and Final Model Analysis")
fig.show()

Full model accuracy on all digits (0-9): 0.6931
Error rate: 0.30689999999999995
Class 0 accuracy: 0.9816
Class 1 accuracy: 0.8106
Class 2 accuracy: 0.5494
Class 3 accuracy: 0.7248
Class 4 accuracy: 0.7668
Class 5 accuracy: 0.4271
Class 6 accuracy: 0.6221
Class 7 accuracy: 0.7860
Class 8 accuracy: 0.5893
Class 9 accuracy: 0.6323
Model architecture:
FullAttentionClassifier()
Total number of parameters in the final model: 640


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import random
from collections import Counter

# Parameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 128
epochs = 3
num_classes = 10  # the first 10 classes
min_samples_per_class = 1000
max_test_samples = 50000

# ---------------------------
# ImageNet dataset preparation
# ---------------------------
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# Load ImageNet.
# torchvision cannot download ImageNet: the ILSVRC2012 archives (including
# ILSVRC2012_devkit_t12.tar.gz) must be obtained separately and placed in
# ./data beforehand. `download=True` is therefore not passed — depending on
# the torchvision release it is either rejected outright or forwarded as an
# unexpected keyword argument.
train_dataset = datasets.ImageNet("./data", split="train", transform=transform)
test_dataset = datasets.ImageNet("./data", split="val", transform=transform)


def filter_classes(dataset, num_classes, min_samples):
    """Keep the first ``min_samples`` samples of each class ``0..num_classes-1``.

    Labels are read from ``dataset.targets`` when that attribute exists, so no
    sample has to be loaded or transformed just to learn its class. Datasets
    without ``targets`` are iterated as ``(sample, target)`` pairs instead.

    Args:
        dataset: indexable dataset; either exposes ``targets`` (one label per
            sample, in index order) or yields ``(sample, target)`` pairs.
        num_classes: number of leading class ids to keep.
        min_samples: required, and retained, number of samples per class.

    Returns:
        ``Subset`` of ``dataset`` whose indices are grouped by class in
        ascending class order, ``min_samples`` indices per class.

    Raises:
        ValueError: if any kept class has fewer than ``min_samples`` samples.
    """
    class_indices = {cls: [] for cls in range(num_classes)}

    targets = getattr(dataset, "targets", None)
    if targets is not None:
        labelled = enumerate(targets)
    else:
        # Fallback: the only way to see the labels is to materialise each item.
        labelled = ((idx, target) for idx, (_, target) in enumerate(dataset))

    for idx, target in labelled:
        target = int(target)
        # Membership test (rather than `< num_classes`) also skips negative ids.
        if target in class_indices:
            class_indices[target].append(idx)

    # Make sure every class has the required number of samples
    for cls, indices in class_indices.items():
        if len(indices) < min_samples:
            raise ValueError(f"Класс {cls} имеет только {len(indices)} образцов, требуется минимум {min_samples}.")

    # Gather the indices to select, class by class
    selected_indices = []
    for indices in class_indices.values():
        selected_indices.extend(indices[:min_samples])

    return Subset(dataset, selected_indices)


# Filter the training and evaluation splits
train_subset = filter_classes(train_dataset, num_classes, min_samples_per_class)
# The ILSVRC2012 validation split holds 50 images per class, so requiring
# `min_samples_per_class` (1000) there would always raise ValueError.
val_samples_per_class = min(min_samples_per_class, 50)
test_subset = filter_classes(test_dataset, num_classes, val_samples_per_class)

# Cap the evaluation set at max_test_samples (50,000) samples
if len(test_subset) > max_test_samples:
    test_indices = random.sample(range(len(test_subset)), max_test_samples)
    test_subset = Subset(test_subset, test_indices)

# Data loaders: shuffle only the training data
train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False)


# ---------------------------
# Autoencoder for pre-training (backbone)
# ---------------------------
class Autoencoder(nn.Module):
    """Convolutional autoencoder with a linear bottleneck for 3-channel images.

    The linear layers are sized for a 128x56x56 feature map, which is what the
    two shape-preserving convolutions plus two 2x2 poolings produce for
    224x224 inputs.
    """

    def __init__(self, latent_dim=512):
        super().__init__()
        conv_stack = [
            nn.Conv2d(3, 64, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        ]
        self.encoder = nn.Sequential(*conv_stack)
        self.fc_enc = nn.Linear(128 * 56 * 56, latent_dim)
        self.fc_dec = nn.Linear(latent_dim, 128 * 56 * 56)
        deconv_stack = [
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 4, 2, 1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*deconv_stack)

    def forward(self, x):
        """Return ``(x_recon, z)``: the reconstruction and the latent code."""
        z = self.encode(x)
        # Expand the code back to the 128x56x56 map the transposed convs expect.
        feature_map = self.fc_dec(z).view(z.size(0), 128, 56, 56)
        return self.decoder(feature_map), z

    def encode(self, x):
        """Return only the latent code of ``x`` (no decoding)."""
        features = self.encoder(x)
        flat = features.view(features.size(0), -1)
        return self.fc_enc(flat)


# Pre-train the autoencoder
autoenc = Autoencoder(latent_dim=1024).to(device)
optimizer_ae = optim.Adam(autoenc.parameters(), lr=0.001)
criterion_ae = nn.MSELoss()

# Per-epoch log of the autoencoder reconstruction loss
ae_loss_log = []
for ep in range(epochs):
    autoenc.train()
    running_loss = 0.0
    for d, _ in train_loader:
        d = d.to(device)
        optimizer_ae.zero_grad()
        x_recon, z = autoenc(d)
        # NOTE(review): batches from train_loader have gone through
        # transforms.Normalize, so values can fall outside [0, 1], while the
        # decoder ends in Sigmoid. The MSE target may therefore be unreachable
        # — confirm whether un-normalized pixels should be the target.
        loss = criterion_ae(x_recon, d)
        loss.backward()
        optimizer_ae.step()
        # Weight the per-batch mean by the batch size so the logged value is a
        # per-sample mean over the epoch.
        running_loss += loss.item() * d.size(0)
    ae_loss_log.append(running_loss / len(train_loader.dataset))

# Freeze the autoencoder: gradients are disabled for all of its parameters
# (encoder and decoder alike).
for p in autoenc.parameters():
    p.requires_grad = False


# ---------------------------
# Attention Classifier
# ---------------------------
class AttentionClassifier(nn.Module):
    """Classifier whose class scores are cosine similarities to learnable queries."""

    _EPS = 1e-9  # added to every norm so an all-zero row cannot divide by zero

    def __init__(self, embedding_dim, num_classes):
        super().__init__()
        # One query vector per class, initialised with small Gaussian noise.
        self.queries = nn.Parameter(torch.randn(num_classes, embedding_dim) * 0.1)

    def forward(self, z):
        """Return the similarity of every row of ``z`` to every class query."""
        query_scale = self.queries.norm(dim=1, keepdim=True) + self._EPS
        input_scale = z.norm(dim=1, keepdim=True) + self._EPS
        directions = self.queries / query_scale
        return torch.matmul(z / input_scale, directions.t())


# Classifier training
def train_classifier(model, loader, enc_model, optimizer, criterion):
    """Run one training epoch of ``model`` on embeddings from ``enc_model``.

    ``model`` is set to train mode and ``enc_model`` to eval mode. Each batch
    is moved to the global ``device``, encoded via ``enc_model.encode`` and
    scored by ``model``; ``optimizer`` then steps on ``criterion``.

    Returns:
        Accuracy over the epoch (scores taken before each update), or ``0.0``
        if the loader produced no samples.
    """
    model.train()
    enc_model.eval()
    right = 0
    count = 0
    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        scores = model(enc_model.encode(inputs))
        criterion(scores, labels).backward()
        optimizer.step()
        predicted = scores.argmax(dim=1)
        right += (predicted == labels).sum().item()
        count += labels.size(0)
    if count > 0:
        return right / count
    return 0.0


# Classifier evaluation
def test_classifier(model, loader, enc_model):
    """Measure the accuracy of ``model`` on ``loader`` and gather its predictions.

    Both networks run in eval mode with gradients disabled; batches are moved
    to the global ``device`` and encoded via ``enc_model.encode``.

    Returns:
        ``(accuracy, pred_all, target_all)`` with the two tensors on the CPU,
        concatenated in loader order; ``(0.0, None, None)`` if no sample was
        seen.
    """
    model.eval()
    enc_model.eval()
    right = 0
    count = 0
    collected_preds = []
    collected_targets = []
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            predicted = model(enc_model.encode(inputs)).argmax(dim=1)
            right += (predicted == labels).sum().item()
            count += labels.size(0)
            collected_preds.append(predicted.cpu())
            collected_targets.append(labels.cpu())
    if count == 0:
        return 0.0, None, None
    return (right / count), torch.cat(collected_preds), torch.cat(collected_targets)


criterion_cls = nn.CrossEntropyLoss()

# ---------------------------
# Final evaluation
# ---------------------------
# A single num_classes-way head trained on the frozen autoencoder embeddings;
# embedding_dim=1024 matches the latent_dim the autoencoder was built with.
model_full = AttentionClassifier(embedding_dim=1024, num_classes=num_classes).to(device)
optimizer_full = optim.Adam(model_full.parameters(), lr=0.001)

for ep in range(epochs):
    train_acc = train_classifier(model_full, train_loader, autoenc, optimizer_full, criterion_cls)
    print(f"Epoch {ep+1}/{epochs}, Train Accuracy: {train_acc:.2f}")

# Only the accuracy is used; the per-sample predictions and targets are discarded.
test_acc, _, _ = test_classifier(model_full, test_loader, autoenc)
print(f"Final Test Accuracy: {test_acc:.2f}")

RuntimeError: The archive ILSVRC2012_devkit_t12.tar.gz is not present in the root directory or is corrupted. You need to download it externally and place it in ./data.