In [1]:
!pip install -q torch torchvision pandas tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, models

import numpy as np
import random, pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)


Device: cuda


In [9]:
IMG_SIZE = 32
NUM_CLASSES = 10
BATCH_TRAIN = 128
BATCH_EVAL = 256

# Аугментации для train / нормализация для val/test
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                         std=[0.2470, 0.2435, 0.2616]),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                         std=[0.2470, 0.2435, 0.2616]),
])

full_train = datasets.CIFAR10(root="./data", train=True, download=True,
                              transform=transform_train)
test_dataset = datasets.CIFAR10(root="./data", train=False, download=True,
                                transform=transform_test)

# Разделяем на train/val
train_size = int(0.8 * len(full_train))
val_size = len(full_train) - train_size
train_dataset, val_dataset = random_split(
    full_train,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED)
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_TRAIN,
                          shuffle=True, num_workers=0, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_EVAL,
                        shuffle=False, num_workers=0, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_EVAL,
                         shuffle=False, num_workers=0, pin_memory=True)

len(train_dataset), len(val_dataset), len(test_dataset)


(40000, 10000, 10000)

In [19]:
class PatchEmbed(nn.Module):
    """
    Простое разбиение на патчи через Conv2d.
    32x32 -> (32/4)^2 = 64 токенов.
    """
    def __init__(self, img_size=32, patch_size=4, in_ch=3, embed_dim=128):
        super().__init__()
        self.proj = nn.Conv2d(in_ch, embed_dim,
                              kernel_size=patch_size, stride=patch_size)
        self.num_patches = (img_size // patch_size) ** 2

    def forward(self, x):
        # x: [B,3,H,W] -> [B, N, D]
        x = self.proj(x)           # [B,D,H',W']
        B, D, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # [B, N, D]
        return x


In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Stochastic depth (DropPath), как в DeiT/Swin
class DropPath(nn.Module):
    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0.0 or not self.training:
            return x
        keep_prob = 1.0 - self.drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(
            shape, dtype=x.dtype, device=x.device
        )
        random_tensor.floor_()
        return x / keep_prob * random_tensor


class VSSDBlock(nn.Module):
    """
    Усиленный VSSD-блок:
      - Pre-norm
      - Non-causal global state (взвешенная сумма токенов)
      - Gated fusion глобального состояния с токенами
      - MLP с γ-рестейлингом и DropPath
    """
    def __init__(self, dim, mlp_ratio=4.0, drop=0.0, drop_path=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.gate = nn.Linear(dim, dim)
        self.weight = nn.Linear(dim, 1)
        self.proj_out = nn.Linear(dim, dim)

        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden, dim),
            nn.Dropout(drop),
        )

        # γ-масштабы (ResScale) помогают стабильности при большой глубине
        self.gamma1 = nn.Parameter(1e-2 * torch.ones(dim))
        self.gamma2 = nn.Parameter(1e-2 * torch.ones(dim))
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        # x: [B, N, D]
        # --- NC-SSD-подобная часть ---
        shortcut = x
        x_norm = self.norm1(x)

        # важность токенов -> softmax по N (non-causal, каждый токен "видит будущее")
        w = self.weight(x_norm).squeeze(-1)       # [B, N]
        alpha = torch.softmax(w, dim=-1)          # [B, N]
        s = torch.einsum("bn,bnd->bd", alpha, x_norm).unsqueeze(1)  # [B,1,D]

        g = torch.sigmoid(self.gate(x_norm))      # [B, N, D]
        y = self.proj_out(g * s)                  # [B, N, D]

        x = shortcut + self.drop_path(self.gamma1 * y)

        # --- MLP ---
        shortcut2 = x
        x_norm2 = self.norm2(x)
        x_mlp = self.mlp(x_norm2)
        x = shortcut2 + self.drop_path(self.gamma2 * x_mlp)
        return x

In [18]:
class TinyVSSD(nn.Module):
    """
    Усиленный TinyVSSD:
      - PatchEmbed
      - CLS-токен + позиционные эмбеддинги
      - stack из VSSDBlock
      - LayerNorm + head по CLS
    """
    def __init__(
        self,
        img_size=32,
        patch_size=4,
        in_ch=3,
        num_classes=10,
        embed_dim=192,
        depth=6,
        mlp_ratio=3.0,
        drop_rate=0.0,
        drop_path_rate=0.1,
    ):
        super().__init__()
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_ch=in_ch,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(drop_rate)

        # линейно растущий drop_path по глубине (как в Swin/ConvNeXt)
        dpr = torch.linspace(0, drop_path_rate, depth).tolist()
        self.blocks = nn.ModuleList([
            VSSDBlock(
                dim=embed_dim,
                mlp_ratio=mlp_ratio,
                drop=drop_rate,
                drop_path=dpr[i],
            )
            for i in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # инициализация
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, x):
        # x: [B,3,H,W]
        x = self.patch_embed(x)           # [B, N, D]
        B, N, D = x.shape

        cls_tokens = self.cls_token.expand(B, -1, -1)  # [B,1,D]
        x = torch.cat((cls_tokens, x), dim=1)          # [B,N+1,D]

        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        cls = x[:, 0]                      # CLS-токен
        logits = self.head(cls)
        return logits


In [26]:
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads=3, mlp_ratio=3.0, drop=0.0, drop_path=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            dropout=drop,
            batch_first=True,
        )
        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(hidden, dim),
            nn.Dropout(drop),
        )
        # γ-скейлинг + DropPath, как в VSSD
        self.gamma1 = nn.Parameter(1e-2 * torch.ones(dim))
        self.gamma2 = nn.Parameter(1e-2 * torch.ones(dim))
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x):
        # x: [B, N, D]
        shortcut = x
        x_norm = self.norm1(x)
        attn_out, _ = self.attn(x_norm, x_norm, x_norm)
        x = shortcut + self.drop_path(self.gamma1 * attn_out)

        shortcut2 = x
        x_norm2 = self.norm2(x)
        x_mlp = self.mlp(x_norm2)
        x = shortcut2 + self.drop_path(self.gamma2 * x_mlp)
        return x


class TinyViT(nn.Module):
    """
    Лёгкий ViT-аналог для сравнения с TinyVSSD:
    PatchEmbed -> CLS + pos -> TransformerBlock * depth -> CLS head.
    """
    def __init__(
        self,
        img_size=32,
        patch_size=4,
        in_ch=3,
        num_classes=10,
        embed_dim=192,
        depth=6,
        num_heads=3,
        mlp_ratio=3.0,
        drop_rate=0.0,
        drop_path_rate=0.1,
    ):
        super().__init__()
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_ch=in_ch,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(drop_rate)

        dpr = torch.linspace(0, drop_path_rate, depth).tolist()
        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                drop=drop_rate,
                drop_path=dpr[i],
            )
            for i in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.patch_embed(x)         # [B, N, D]
        B, N, D = x.shape

        cls_tokens = self.cls_token.expand(B, -1, -1)  # [B,1,D]
        x = torch.cat((cls_tokens, x), dim=1)          # [B,N+1,D]

        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        cls = x[:, 0]
        logits = self.head(cls)
        return logits


In [21]:
def train_one_epoch(model, loader, optimizer, criterion):
    model.train()
    total_loss, correct, total = 0.0, 0, 0

    for images, labels in tqdm(loader, leave=False):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * images.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    return total_loss / total, correct / total


@torch.no_grad()
def evaluate(model, loader, criterion=None):
    model.eval()
    total_loss, correct, total = 0.0, 0, 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        if criterion is not None:
            loss = criterion(logits, labels)
            total_loss += loss.item() * images.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    if criterion is None:
        return None, correct / total
    else:
        return total_loss / total, correct / total


def run_training(model, train_loader, val_loader,
                 epochs=10, lr=3e-4, weight_decay=1e-4):
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr,
                                  weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           T_max=epochs)

    history = {"train_loss": [], "train_acc": [],
               "val_loss": [], "val_acc": []}

    for ep in range(1, epochs+1):
        print(f"Epoch {ep}/{epochs}")
        train_loss, train_acc = train_one_epoch(model, train_loader,
                                                optimizer, criterion)
        val_loss, val_acc = evaluate(model, val_loader, criterion)

        scheduler.step()

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

        print(f"  train_loss={train_loss:.4f}, train_acc={train_acc:.4f}, "
              f"val_loss={val_loss:.4f}, val_acc={val_acc:.4f}")
    return history


In [31]:
vssd_model = TinyVSSD(
    img_size=IMG_SIZE,
    patch_size=4,
    in_ch=3,
    num_classes=NUM_CLASSES,
    embed_dim=192,
    depth=6,
    mlp_ratio=3.0,
    drop_rate=0.1,
    drop_path_rate=0.1,
).to(device)

vit_model = TinyViT(
    img_size=IMG_SIZE,
    patch_size=4,
    in_ch=3,
    num_classes=NUM_CLASSES,
    embed_dim=192,
    depth=6,
    num_heads=3,
    mlp_ratio=3.0,
    drop_rate=0.1,
    drop_path_rate=0.1,
).to(device)

EPOCHS = 20
LR = 1e-3

print("=== TinyVSSD training ===")
hist_vssd = run_training(vssd_model, train_loader, val_loader,
                         epochs=EPOCHS, lr=LR)

print("\n=== TinyViT training ===")
hist_vit = run_training(vit_model, train_loader, val_loader,
                        epochs=EPOCHS, lr=LR)


=== TinyVSSD training ===
Epoch 1/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=1.8607, train_acc=0.3021, val_loss=1.6485, val_acc=0.3957
Epoch 2/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=1.5997, train_acc=0.4139, val_loss=1.4658, val_acc=0.4597
Epoch 3/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=1.4590, train_acc=0.4709, val_loss=1.3924, val_acc=0.4922
Epoch 4/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=1.3585, train_acc=0.5065, val_loss=1.2682, val_acc=0.5407
Epoch 5/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=1.2759, train_acc=0.5355, val_loss=1.2907, val_acc=0.5370
Epoch 6/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=1.2054, train_acc=0.5625, val_loss=1.1380, val_acc=0.5874
Epoch 7/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=1.1346, train_acc=0.5890, val_loss=1.0666, val_acc=0.6163
Epoch 8/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=1.0778, train_acc=0.6103, val_loss=1.0386, val_acc=0.6260
Epoch 9/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=1.0270, train_acc=0.6262, val_loss=0.9939, val_acc=0.6444
Epoch 10/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=0.9813, train_acc=0.6436, val_loss=0.9724, val_acc=0.6491
Epoch 11/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=0.9400, train_acc=0.6603, val_loss=0.9209, val_acc=0.6696
Epoch 12/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=0.9077, train_acc=0.6724, val_loss=0.8958, val_acc=0.6798
Epoch 13/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=0.8681, train_acc=0.6884, val_loss=0.8707, val_acc=0.6921
Epoch 14/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=0.8361, train_acc=0.6986, val_loss=0.8384, val_acc=0.7002
Epoch 15/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=0.8114, train_acc=0.7088, val_loss=0.8283, val_acc=0.7048
Epoch 16/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=0.7823, train_acc=0.7210, val_loss=0.8166, val_acc=0.7079
Epoch 17/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=0.7637, train_acc=0.7264, val_loss=0.8120, val_acc=0.7179
Epoch 18/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=0.7457, train_acc=0.7316, val_loss=0.8054, val_acc=0.7137
Epoch 19/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=0.7324, train_acc=0.7346, val_loss=0.7977, val_acc=0.7198
Epoch 20/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=0.7273, train_acc=0.7399, val_loss=0.7940, val_acc=0.7194

=== TinyViT training ===
Epoch 1/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=1.7560, train_acc=0.3462, val_loss=1.5002, val_acc=0.4481
Epoch 2/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=1.4929, train_acc=0.4577, val_loss=1.3715, val_acc=0.5026
Epoch 3/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=1.3928, train_acc=0.4934, val_loss=1.3026, val_acc=0.5301
Epoch 4/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=1.3183, train_acc=0.5229, val_loss=1.2592, val_acc=0.5466
Epoch 5/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=1.2604, train_acc=0.5454, val_loss=1.2081, val_acc=0.5628
Epoch 6/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=1.2109, train_acc=0.5606, val_loss=1.1557, val_acc=0.5757
Epoch 7/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=1.1505, train_acc=0.5844, val_loss=1.0867, val_acc=0.6061
Epoch 8/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=1.1088, train_acc=0.6002, val_loss=1.0310, val_acc=0.6329
Epoch 9/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=1.0587, train_acc=0.6189, val_loss=1.0209, val_acc=0.6341
Epoch 10/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=1.0176, train_acc=0.6325, val_loss=0.9686, val_acc=0.6527
Epoch 11/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=0.9755, train_acc=0.6536, val_loss=0.9409, val_acc=0.6695
Epoch 12/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=0.9384, train_acc=0.6630, val_loss=0.9113, val_acc=0.6710
Epoch 13/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=0.9031, train_acc=0.6772, val_loss=0.8744, val_acc=0.6903
Epoch 14/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=0.8740, train_acc=0.6868, val_loss=0.8517, val_acc=0.6965
Epoch 15/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=0.8408, train_acc=0.6972, val_loss=0.8329, val_acc=0.7072
Epoch 16/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=0.8161, train_acc=0.7079, val_loss=0.8128, val_acc=0.7124
Epoch 17/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=0.7889, train_acc=0.7185, val_loss=0.8040, val_acc=0.7158
Epoch 18/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=0.7780, train_acc=0.7237, val_loss=0.7897, val_acc=0.7250
Epoch 19/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=0.7717, train_acc=0.7240, val_loss=0.7887, val_acc=0.7208
Epoch 20/20


  0%|          | 0/313 [00:00<?, ?it/s]

  train_loss=0.7640, train_acc=0.7268, val_loss=0.7867, val_acc=0.7293


In [36]:
criterion = nn.CrossEntropyLoss()

test_loss_vssd, test_acc_vssd = evaluate(vssd_model, test_loader, criterion)
test_loss_vit,  test_acc_vit  = evaluate(vit_model,  test_loader, criterion)

print(f"TinyVSSD  test_loss={test_loss_vssd:.4f}, test_acc={test_acc_vssd:.4f}")
print(f"TinyViT   test_loss={test_loss_vit:.4f},  test_acc={test_acc_vit:.4f}")


TinyVSSD  test_loss=0.7869, test_acc=0.7258
TinyViT   test_loss=0.7580,  test_acc=0.7334


In [37]:
def count_params_m(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6

import pandas as pd

results = pd.DataFrame([
    {
        "model": "TinyVSSD",
        "best_val_acc": max(hist_vssd["val_acc"]),
        "test_acc": test_acc_vssd,
        "params_M": count_params_m(vssd_model),
    },
    {
        "model": "TinyViT",
        "best_val_acc": max(hist_vit["val_acc"]),
        "test_acc": test_acc_vit,
        "params_M": count_params_m(vit_model),
    },
]).round(4)

results


Unnamed: 0,model,best_val_acc,test_acc,params_M
0,TinyVSSD,0.7198,0.7258,1.8088
1,TinyViT,0.7293,0.7334,2.2524


# Выводы
По результатам эксперимента TinyVSSD не улучшает качество по сравнению с аналогом-трансформером, но даёт сопоставимый уровень при меньшем числе параметров.TinyVSSD использует около 1.81M обучаемых параметров против 2.25M у TinyViT (≈-20% параметров). Кривые обучения показывают схожую динамику: обе модели плавно выходят на зону 0.7+ accuracy без заметного переобучения.

