# ResNet-32 VAE with Perceptual Loss

This notebook demonstrates semi-supervised learning on CIFAR-10 using a ResNet-32 backbone. The approach first trains a Variational Autoencoder (VAE) on the unlabeled portion of the training set. Its loss combines a pixel-wise MSE reconstruction term, a VGG-16 perceptual term, and a KL-divergence term. The learned encoder is then fine-tuned as an image classifier on 500 labeled examples.


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
import numpy as np


In [None]:

# Training-time augmentation: padded random crop, horizontal flip, colour
# jitter and occasional grayscale, followed by conversion to a float tensor.
# No mean/std normalisation is applied in this pipeline.
augmentations = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
    transforms.RandomGrayscale(p=0.1),
]
train_transform = transforms.Compose(augmentations + [transforms.ToTensor()])

# Evaluation: tensor conversion only.
test_transform = transforms.Compose([transforms.ToTensor()])

DATA_ROOT = 'data'
NUM_LABELED = 500

# Full CIFAR-10 training split; both subsets below inherit train_transform.
full_train = datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=train_transform)

# Class-stratified, seeded split: NUM_LABELED labeled indices, the rest unlabeled.
indices = np.arange(len(full_train))
train_idx, unlabeled_idx = train_test_split(
    indices,
    train_size=NUM_LABELED,
    stratify=full_train.targets,
    random_state=42,
)
labeled_set, unlabeled_set = Subset(full_train, train_idx), Subset(full_train, unlabeled_idx)

# Held-out evaluation data: the CIFAR-10 test split (train=False).
val_set = datasets.CIFAR10(root=DATA_ROOT, train=False, download=True, transform=test_transform)

labeled_loader = DataLoader(labeled_set, batch_size=64, shuffle=True, num_workers=2)
unlabeled_loader = DataLoader(unlabeled_set, batch_size=128, shuffle=True, num_workers=2)
val_loader = DataLoader(val_set, batch_size=128, shuffle=False, num_workers=2)


In [None]:

# Simple ResNet-32 implementation for CIFAR-10
class BasicBlock(nn.Module):
    """Two-convolution 3x3 residual block.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``.
    ``conv1`` carries ``stride``; ``conv2`` always uses stride 1. When the
    block changes resolution (``stride != 1``) or width
    (``in_planes != planes``), the shortcut is a strided 1x1 convolution plus
    batch norm. Otherwise it is an empty ``nn.Sequential``, which returns its
    input unchanged.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != planes
        projection = (
            [nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
             nn.BatchNorm2d(planes)]
            if needs_projection
            else []
        )
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        residual = self.shortcut(x)
        y = torch.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        return torch.relu(y + residual)

def make_layer(in_planes, planes, num_blocks, stride):
    """Stack BasicBlocks into one ResNet stage.

    The first block maps ``in_planes`` to ``planes`` with the given
    ``stride`` and does any downsampling or widening. Each following block
    maps ``planes`` to ``planes`` at stride 1. At least one block is always
    created, even if ``num_blocks`` is smaller than 1.
    """
    head = BasicBlock(in_planes, planes, stride)
    tail = [BasicBlock(planes, planes) for _ in range(num_blocks - 1)]
    return nn.Sequential(head, *tail)

class ResNet32(nn.Module):
    """ResNet-32 for small (CIFAR-sized) images.

    The network has these parts:
    - a 3x3 stem with 16 channels, followed by batch norm and ReLU;
    - three stages of five BasicBlocks each, with 16, 32 and 64 channels
      (stages 2 and 3 use stride 2 in their first block);
    - global average pooling and a linear head producing ``num_classes``
      logits.

    That gives 1 + 3 * 5 * 2 + 1 = 32 weighted layers.
    """
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = make_layer(16, 16, num_blocks=5, stride=1)
        self.layer2 = make_layer(16, 32, num_blocks=5, stride=2)
        self.layer3 = make_layer(32, 64, num_blocks=5, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        features = torch.relu(self.bn1(self.conv1(x)))
        # Stage modules are looked up on every call, so reassigning
        # self.layerN from outside takes effect immediately.
        for stage in (self.layer1, self.layer2, self.layer3):
            features = stage(features)
        pooled = torch.flatten(self.avgpool(features), 1)
        return self.fc(pooled)


In [None]:

class ResNet32VAE(nn.Module):
    """Variational autoencoder with a ResNet-32 encoder.

    Encoder: the stem and three residual stages of ``ResNet32``, followed by
    global average pooling. The backbone's ``fc`` head is discarded. The
    resulting 64-d feature vector is mapped by two linear heads to the
    posterior mean and log-variance.

    Decoder: a linear projection of ``z`` to a 64x8x8 map, then two stride-2
    transposed convolutions (8 -> 16 -> 32) and a final 3x3 transposed
    convolution to 3 channels. A sigmoid squashes the output into (0, 1).
    Reconstructions are therefore always 3x32x32, matching CIFAR-10 inputs.

    ``vgg_features`` holds the first 16 layers of an ImageNet-pretrained
    VGG-16, with gradients disabled. It is kept for a perceptual loss and is
    not used in ``forward``.

    Args:
        latent_dim: Size of the latent vector ``z``.
    """

    # Shape of the feature map the decoder starts from. Two stride-2
    # transposed convolutions upsample _DEC_SIZE -> 2*_DEC_SIZE -> 4*_DEC_SIZE.
    _DEC_CHANNELS = 64
    _DEC_SIZE = 8

    def __init__(self, latent_dim=128):
        super().__init__()
        backbone = ResNet32(num_classes=latent_dim*2)
        self.encoder = nn.Sequential(backbone.conv1, backbone.bn1, nn.ReLU(),
                                     backbone.layer1, backbone.layer2, backbone.layer3, backbone.avgpool)
        self.fc_mu = nn.Linear(64, latent_dim)
        self.fc_logvar = nn.Linear(64, latent_dim)
        # Starting the decoder from a 1x1 map would only reach 4x4 images
        # (1 -> 2 -> 4 -> 4), which cannot be compared with 32x32 inputs.
        # An 8x8 start is upsampled to 32x32.
        self.decoder_fc = nn.Linear(latent_dim, self._DEC_CHANNELS * self._DEC_SIZE * self._DEC_SIZE)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(self._DEC_CHANNELS, 32, kernel_size=4, stride=2, padding=1),  # 8 -> 16
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),  # 16 -> 32
            nn.ReLU(),
            nn.ConvTranspose2d(16, 3, kernel_size=3, stride=1, padding=1),   # 32 -> 32
            nn.Sigmoid()
        )
        try:
            vgg = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)
        except AttributeError:
            # torchvision < 0.13 has no weights enums; fall back to the legacy flag.
            vgg = models.vgg16(pretrained=True)
        self.vgg_features = nn.Sequential(*list(vgg.features)[:16])
        for p in self.vgg_features.parameters():
            p.requires_grad = False

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior q(z | x)."""
        h = torch.flatten(self.encoder(x), 1)  # (B, 64, 1, 1) -> (B, 64)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        """Map latent vectors ``z`` of shape (B, latent_dim) to (B, 3, 32, 32) images."""
        h = self.decoder_fc(z)
        h = h.view(-1, self._DEC_CHANNELS, self._DEC_SIZE, self._DEC_SIZE)
        return self.decoder(h)

    def forward(self, x):
        """Encode, sample and decode; return ``(recon, mu, logvar)``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar

def perceptual_loss(x, recon, vgg, normalize=True):
    """Return the L1 distance between feature maps of ``recon`` and ``x``.

    Args:
        x: Target images. Assumed to be RGB (B, 3, H, W) with values in
            [0, 1] when ``normalize`` is True (TODO confirm for other
            callers). Treated as a constant: no gradient flows into it.
        recon: Reconstructed images with the same shape as ``x``. Gradients
            flow back through this argument, so the loss trains whatever
            produced it.
        vgg: Feature extractor applied to both inputs.
        normalize: If True, standardise both inputs with the ImageNet
            per-channel mean/std before ``vgg``, as torchvision's pretrained
            weights expect.

    Returns:
        Scalar tensor: mean absolute difference between ``vgg(recon)`` and
        ``vgg(x)``.
    """
    if normalize:
        mean = recon.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        std = recon.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        x = (x - mean) / std
        recon = (recon - mean) / std
    # Only the target branch runs without autograd. The reconstruction branch
    # must stay in the graph; otherwise this term contributes no gradient.
    with torch.no_grad():
        f_x = vgg(x.detach())
    f_recon = vgg(recon)
    return nn.functional.l1_loss(f_recon, f_x)


In [None]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ResNet32VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Weights of the perceptual and KL terms relative to the pixel-wise MSE term.
PERCEPTUAL_WEIGHT = 0.1
KLD_WEIGHT = 1e-3

num_epochs = 50
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for data, _ in unlabeled_loader:
        data = data.to(device)
        optimizer.zero_grad()
        recon, mu, logvar = model(data)
        # Closed-form KL(q(z|x) || N(0, I)). torch.mean averages over both
        # the batch and the latent dimensions.
        kld = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
        loss = (nn.functional.mse_loss(recon, data)
                + PERCEPTUAL_WEIGHT * perceptual_loss(data, recon, model.vgg_features)
                + KLD_WEIGHT * kld)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f'Epoch {epoch+1}: VAE loss {total_loss/len(unlabeled_loader):.4f}')


In [None]:

def evaluate(net, loader, device):
    """Return the top-1 accuracy (percent) of ``net`` over all batches of ``loader``.

    Switches ``net`` to eval mode, so BatchNorm uses its running statistics,
    and disables autograd. Callers must call ``net.train()`` again before
    further training.
    """
    net.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            hits += net(images).argmax(dim=1).eq(labels).sum().item()
            seen += labels.size(0)
    return 100.0 * hits / seen


# Build a classifier whose feature extractor is the pretrained VAE encoder.
# The modules are shared by reference with model.encoder, so the fine-tuning
# below also changes the VAE's encoder weights.
encoder = model.encoder
classifier = ResNet32(num_classes=10)
classifier.conv1 = encoder[0]
classifier.bn1 = encoder[1]
# encoder[2] is the ReLU after bn1; ResNet32.forward applies its own ReLU there.
classifier.layer1 = encoder[3]
classifier.layer2 = encoder[4]
classifier.layer3 = encoder[5]
classifier.avgpool = encoder[6]
classifier.fc = nn.Linear(64, 10)  # fresh, randomly initialised classification head
classifier = classifier.to(device)

optimizer_c = optim.Adam(classifier.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for epoch in range(30):
    classifier.train()
    total, correct = 0, 0
    for data, targets in labeled_loader:
        data, targets = data.to(device), targets.to(device)
        optimizer_c.zero_grad()
        out = classifier(data)
        loss = criterion(out, targets)
        loss.backward()
        optimizer_c.step()
        pred = out.argmax(dim=1)
        total += targets.size(0)
        correct += pred.eq(targets).sum().item()
    # Running accuracy on the small augmented labeled set, measured in train mode.
    acc = 100.0 * correct / total
    # Held-out accuracy on val_set (the CIFAR-10 test split). Training
    # accuracy on the labeled subset alone does not measure generalisation.
    val_acc = evaluate(classifier, val_loader, device)
    print(f'Epoch {epoch+1}: train accuracy {acc:.2f}% | val accuracy {val_acc:.2f}%')
