In [None]:
import torch
import wandb

from default_mnist_config import create_default_mnist_config
# from diffusion import DiffusionRunner
from models.classifier import ResNet, ResidualBlock, ConditionalResNet
from data_generator import DataGenerator

from tqdm.auto import trange

import os

In [None]:
device = torch.device('cuda')
classifier_args = {
    "block": ResidualBlock,
    "layers": [2, 2, 2, 2]
}
model = ResNet(**classifier_args)
model.to(device)

optim = torch.optim.AdamW(model.parameters(), lr=1e-3)
loss_func = torch.nn.CrossEntropyLoss()

In [None]:
datagen = DataGenerator(create_default_mnist_config())
train_generator = datagen.sample_train()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 84288815.77it/s]


Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 108447353.47it/s]


Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 21204634.05it/s]


Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 5193710.13it/s]


Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw



In [None]:
TOTAL_ITERS = 2_000
EVAL_FREQ = 500

### Обучите классификатор только на чистых картинках. Он понадобится нам для классификации условно сгенерированных картинок

In [None]:
model.train()

for iter_idx in trange(1, 1 + TOTAL_ITERS):
    
    """
    train
    """
    x, y = next(train_generator)
    x = x.to(device)
    y = y.to(device)

    loss = loss_func(model(x), y)

    optim.zero_grad()
    loss.backward()
    optim.step()
            
    if iter_idx % EVAL_FREQ == 0:
        valid_accuracy = 0
        valid_count = 0
        """
        validate
        """
        model.eval()
        
        with torch.no_grad():
            for (x, y) in datagen.valid_loader:
                x = x.to(device)
                y = y.to(device)

                pred_labels = model(x).max(dim=1)
                pred_labels = pred_labels[1]
                
                valid_accuracy += (y == pred_labels).sum().item()
                valid_count += x.size()[0]
                
        valid_accuracy = valid_accuracy / valid_count
                        
        print('Clean MNIST classifier\'s accuracy:', valid_accuracy)
model.eval()

  0%|          | 0/2000 [00:00<?, ?it/s]

Clean MNIST classifier's accuracy: 0.9902
Clean MNIST classifier's accuracy: 0.8389
Clean MNIST classifier's accuracy: 0.9628
Clean MNIST classifier's accuracy: 0.9773


ResNet(
  (conv): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): ResidualBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): ResidualBlock(
      (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=

In [15]:
torch.save(model.state_dict(), './ddpm_checkpoints/clean_classifier.pth')