In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class LadderNet(nn.Module):
    def __init__(self, input_shape, num_classes):
        super(LadderNet, self).__init__()
        self.input_shape = input_shape
        self.num_classes = num_classes

        # Encoder layers
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)

        # Decoder layers
        self.deconv4 = nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(128)
        self.deconv3 = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
        self.bn6 = nn.BatchNorm2d(64)
        self.deconv2 = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
        self.bn7 = nn.BatchNorm2d(32)
        self.deconv1 = nn.ConvTranspose2d(32, 1, kernel_size=4, padding=1)

        # Classification layer
        self.fc = nn.Linear(256, num_classes)

    def forward(self, x_labeled, x_unlabeled):
        # Labeled input
        h = F.relu(self.bn1(self.conv1(x_labeled)))
        h = F.max_pool2d(h, 2, 2)
        h = F.relu(self.bn2(self.conv2(h)))
        h = F.max_pool2d(h, 2, 2)
        h = F.relu(self.bn3(self.conv3(h)))
        h = F.max_pool2d(h, 2, 2)
        h = F.relu(self.bn4(self.conv4(h)))
        h_labeled = F.avg_pool2d(h, h.size()[2:])
        h_labeled = h_labeled.view(-1, 256)

        # Unlabeled input
        u = F.relu(self.bn1(self.conv1(x_unlabeled)))
        u = F.max_pool2d(u, 2, 2)
        u = F.relu(self.bn2(self.conv2(u)))
        u = F.max_pool2d(u, 2, 2)
        u = F.relu(self.bn3(self.conv3(u)))
        u = F.max_pool2d(u, 2, 2)
        u = F.relu(self.bn4(self.conv4(u)))
        h_unlabeled = F.avg_pool2d(u, u.size()[2:])
        h_unlabeled = h_unlabeled.view(-1, 256)
        # Add noise to the hidden representations
        noise_labeled = torch.randn_like(h_labeled)
        noise_unlabeled = torch.randn_like(h_unlabeled)
        h_labeled_noisy = h_labeled + noise_labeled
        h_unlabeled_noisy = h_unlabeled + noise_unlabeled

        # Decoder
        u = F.relu(self.bn5(self.deconv4(h_labeled_noisy.view(-1, 256, 1, 1))))
        u = F.interpolate(u, scale_factor=3, mode='nearest')
        u = F.relu(self.bn6(self.deconv3(u)))
        u = F.interpolate(u, scale_factor=3, mode='nearest')
        u = F.relu(self.bn7(self.deconv2(u)))
        u = F.interpolate(u, scale_factor=3, mode='nearest')
        x_reconstructed = torch.sigmoid(self.deconv1(u))

        # Classification
        output = self.fc(h_labeled)

        return output, x_reconstructed, h_labeled, h_unlabeled_noisy

    def labeled_loss(self, output_labeled, target_labeled):
        criterion = nn.CrossEntropyLoss()
        return criterion(output_labeled, target_labeled)

    def unlabeled_loss(self, x_reconstructed, x_unlabeled, h_labeled, h_unlabeled_noisy):
        # MSE loss between reconstructed and unlabeled input
        mse_loss = nn.MSELoss()
        mse = mse_loss(x_reconstructed, x_unlabeled)
        # KL divergence between hidden representations with noise and without noise
        kl_div = 0.5 * torch.mean(
            torch.sum(h_unlabeled_noisy.pow(2), dim=1) - torch.sum(h_unlabeled_noisy - h_labeled.pow(2),
                                                                   dim=1) + torch.sum(
                torch.log(torch.std(h_labeled, dim=0) / torch.std(h_unlabeled_noisy, dim=0))))

        return mse + kl_div

In [2]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Define the transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load the MNIST dataset
train_set = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_set = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

# Divide the training set into labeled and unlabeled sets
n_labeled = int(0.1 * len(train_set))
n_unlabeled = len(train_set) - n_labeled
train_labeled_set, train_unlabeled_set = torch.utils.data.random_split(train_set, [n_labeled, n_unlabeled])

# Create data loaders for the labeled and unlabeled sets
batch_size = 128

In [3]:

from torch.utils.data import Dataset


class SemisupervisedDataset(Dataset):
    def __init__(self, labeled_dataset, unlabeled_dataset):
        self.labeled_dataset = labeled_dataset
        self.unlabeled_dataset = unlabeled_dataset
        self.unlabeled_size = len(unlabeled_dataset)
        self.labeled_size = len(labeled_dataset)
        print("Dataset created")

    def __len__(self):
        return self.labeled_size

    def __getitem__(self, index):
        labeled_data, labeled_target = self.labeled_dataset[index % self.labeled_size]
        unlabeled_data, _ = self.unlabeled_dataset[index]

        return labeled_data, labeled_target, unlabeled_data

In [4]:
dataset = SemisupervisedDataset(train_labeled_set, train_unlabeled_set)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, pin_memory=True)

Dataset created


In [5]:
# Check if a GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [6]:
from torch import optim

# Define the model
model = LadderNet(input_shape=(1, 28, 28), num_classes=10)

# Move the model to the device
model.to(device)

# Define the loss functions
labeled_loss_fn = model.labeled_loss
unlabeled_loss_fn = model.unlabeled_loss

num_epochs = 10

# Define the optimizer
optimizer = optim.Adam(model.parameters())

In [7]:
from tqdm import tqdm

# Train the model
for epoch in tqdm(range(num_epochs)):
    model.train()
    for data_labeled, target_labeled, data_unlabeled in dataloader:
        # Train the model on the labeled data
        data_labeled = data_labeled.to(device)
        target_labeled = target_labeled.to(device)
        data_unlabeled = data_unlabeled.to(device)

        optimizer.zero_grad()
        output_labeled, x_reconstructed, h_labeled, h_unlabeled_noisy = model(data_labeled, data_unlabeled)
        loss_labeled = labeled_loss_fn(output_labeled, target_labeled)
        loss_labeled.backward()
        optimizer.step()

        # Train the model on the unlabeled data

        optimizer.zero_grad()
        x_unlabeled = data_unlabeled.view(-1, 784)
        x_reconstructed = x_reconstructed.view(-1, 784)
        loss_unlabeled = unlabeled_loss_fn(x_reconstructed, x_unlabeled, h_labeled, h_unlabeled_noisy)
        # loss_unlabeled.backward()
        # optimizer.step()


    # Test the model
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            output, _, _, _ = model(data, data_unlabeled)
            test_loss += labeled_loss_fn(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    print('Epoch: {} Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(epoch, test_loss, correct, len(test_loader.dataset), accuracy))

 10%|█         | 1/10 [00:08<01:14,  8.32s/it]

Epoch: 0 Test set: Average loss: 0.0051, Accuracy: 7732/10000 (77%)


 20%|██        | 2/10 [00:14<00:55,  6.92s/it]

Epoch: 1 Test set: Average loss: 0.0041, Accuracy: 8113/10000 (81%)


 30%|███       | 3/10 [00:20<00:45,  6.49s/it]

Epoch: 2 Test set: Average loss: 0.0038, Accuracy: 8280/10000 (83%)


 40%|████      | 4/10 [00:26<00:37,  6.25s/it]

Epoch: 3 Test set: Average loss: 0.0032, Accuracy: 8577/10000 (86%)


 50%|█████     | 5/10 [00:32<00:30,  6.12s/it]

Epoch: 4 Test set: Average loss: 0.0030, Accuracy: 8656/10000 (87%)


 60%|██████    | 6/10 [00:37<00:24,  6.06s/it]

Epoch: 5 Test set: Average loss: 0.0045, Accuracy: 8221/10000 (82%)


 70%|███████   | 7/10 [00:43<00:18,  6.05s/it]

Epoch: 6 Test set: Average loss: 0.0034, Accuracy: 8589/10000 (86%)


 80%|████████  | 8/10 [00:49<00:12,  6.02s/it]

Epoch: 7 Test set: Average loss: 0.0039, Accuracy: 8512/10000 (85%)


 90%|█████████ | 9/10 [00:55<00:05,  5.99s/it]

Epoch: 8 Test set: Average loss: 0.0056, Accuracy: 8003/10000 (80%)


100%|██████████| 10/10 [01:01<00:00,  6.18s/it]

Epoch: 9 Test set: Average loss: 0.0047, Accuracy: 8292/10000 (83%)



