https://stepik.org/lesson/1576210/step/4

In [None]:
from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms as tfs
# import torchvision.transforms.v2 as tfs_v2 - not available on Stepik

# declare the VAE_CNN class here
class VAE_CNN(nn.Module):
    """Small convolutional variational autoencoder.

    The encoder maps a ``[batch_size, 3, 16, 16]`` input to a flat
    36-element feature vector. Two linear heads turn those features into
    the mean and the log-variance of a 7-dimensional latent distribution.
    A latent vector is sampled with the reparameterization trick and the
    decoder expands it into a ``[batch_size, 1, 16, 16]`` tensor with
    values squashed into (0, 1) by a sigmoid.
    """

    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Sequential(
            # [batch_size, 3, 16, 16]
            nn.Conv2d(3, 16, (3, 3), stride=1, padding=1),
            nn.ELU(inplace=True),
            nn.MaxPool2d((3, 3), stride=2),     # [batch_size, 16, 7, 7]
            nn.Conv2d(16, 4, (3, 3), stride=1, padding=1),
            nn.ELU(inplace=True),
            nn.MaxPool2d((3, 3), stride=2),     # [batch_size, 4, 3, 3]
            nn.Flatten(),                       # [batch_size, 36]
        )

        # Latent heads: both read the same flattened encoder features.
        self.h_mean = nn.Linear(36, 7)      # [batch_size, 7]
        self.h_log_var = nn.Linear(36, 7)   # [batch_size, 7]

        self.decoder = nn.Sequential(
            # [batch_size, 7]
            nn.Linear(7, 32),
            nn.ReLU(inplace=True),
            nn.Unflatten(1, (2, 4, 4)),         # [batch_size, 2, 4, 4]
            nn.ConvTranspose2d(2, 8, (2, 2), stride=2, padding=0),
            nn.ReLU(inplace=True),              # [batch_size, 8, 8, 8]
            nn.ConvTranspose2d(8, 1, (2, 2), stride=2, padding=0),
            nn.Sigmoid(),                       # [batch_size, 1, 16, 16]
        )

    def forward(self, x: torch.Tensor):
        """Encode ``x``, sample a latent vector and decode it.

        Args:
            x: input batch of shape ``[batch_size, 3, 16, 16]``.

        Returns:
            A tuple ``(out, h_mean, h_log_var)`` where ``out`` is the
            decoder output of shape ``[batch_size, 1, 16, 16]`` and
            ``h_mean`` / ``h_log_var`` are the ``[batch_size, 7]`` latent
            mean and log-variance.
        """
        # Run the encoder once and share its features between both heads;
        # a second pass would only repeat the same computation.
        features = self.encoder(x)
        h_mean = self.h_mean(features)
        h_log_var = self.h_log_var(features)

        # Reparameterization trick: h = mean + sigma * eps, eps ~ N(0, 1),
        # with sigma = exp(log_var / 2).
        noise = torch.randn_like(h_mean)
        h = torch.exp(h_log_var / 2) * noise + h_mean

        out = self.decoder(h)

        return out, h_mean, h_log_var

# Solid-colour RGB test image, 64 px wide and 78 px tall.
img_pil = Image.new("RGB", (64, 78), (0, 128, 255))

# Preprocessing pipeline: central 64x64 crop, downscale to 16x16,
# then conversion to a tensor.
img_transform = tfs.Compose(
    [
        tfs.CenterCrop(64),
        tfs.Resize(16),
        tfs.ToTensor(),
    ]
)

# Indexing with None prepends the batch axis the model expects.
img = img_transform(img_pil)[None]

# Build the network and switch it to evaluation mode (eval() returns the module).
model = VAE_CNN().eval()

# Single forward pass: reconstruction, latent mean, latent log-variance.
out, hm, hlv = model(img)
out.shape, hm.shape, hlv.shape

(torch.Size([1, 1, 16, 16]), torch.Size([1, 7]), torch.Size([1, 7]))