<a href="https://colab.research.google.com/github/selcuk-yalcin/TrustworthyML/blob/main/Pseudo_Labeling.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Subset, ConcatDataset
import numpy as np

In [2]:
# --- 1. Download and prepare CIFAR-10 dataset ---

transform_train = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize for ResNet input
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

cifar_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
cifar_test = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

In [3]:
# --- 2. Split into labeled and unlabeled datasets ---

labeled_indices = list(range(500))  # Small labeled set (e.g., 500 images)
unlabeled_indices = list(range(500, len(cifar_train)))  # Remaining as unlabeled

labeled_dataset = Subset(cifar_train, labeled_indices)
unlabeled_dataset = Subset(cifar_train, unlabeled_indices)

test_loader = DataLoader(cifar_test, batch_size=128, shuffle=False)
labeled_loader = DataLoader(labeled_dataset, batch_size=32, shuffle=True)
unlabeled_loader = DataLoader(unlabeled_dataset, batch_size=128, shuffle=False)

In [4]:
# --- 3. Load pretrained ResNet-18 model ---

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

resnet = models.resnet18(pretrained=True)

# Modify the final layer for CIFAR-10 (10 classes)
num_ftrs = resnet.fc.in_features
resnet.fc = nn.Linear(num_ftrs, 10)

resnet = resnet.to(device)



In [5]:
# --- 4. Train the model on labeled data only ---

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(resnet.parameters(), lr=0.0005)

def train_model(model, dataloader, epochs=5):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss:.4f}")

print("Step 1: Training on labeled data...")
train_model(resnet, labeled_loader, epochs=5)

Step 1: Training on labeled data...
Epoch [1/5], Loss: 21.4422
Epoch [2/5], Loss: 6.5647
Epoch [3/5], Loss: 1.9181
Epoch [4/5], Loss: 1.2217
Epoch [5/5], Loss: 0.5993


In [None]:
# --- 5. Generate pseudo labels for unlabeled data ---

print("Step 2: Generating pseudo labels...")
resnet.eval()
pseudo_labels = []

with torch.no_grad():
    for images, _ in unlabeled_loader:
        images = images.to(device)
        outputs = resnet(images)
        predicted = torch.argmax(outputs, dim=1)
        pseudo_labels.append(predicted.cpu())

pseudo_labels = torch.cat(pseudo_labels)

In [None]:
# --- 6. Combine labeled and pseudo-labeled data ---

print("Step 3: Combining labeled and pseudo-labeled data...")

unlabeled_images = torch.stack([unlabeled_dataset[i][0] for i in range(len(unlabeled_dataset))])
pseudo_label_tensor = pseudo_labels

pseudo_labeled_dataset = TensorDataset(unlabeled_images, pseudo_label_tensor)
combined_dataset = ConcatDataset([labeled_dataset, pseudo_labeled_dataset])
combined_loader = DataLoader(combined_dataset, batch_size=32, shuffle=True)

In [None]:
# --- 7. Retrain the model on the combined data ---

print("Step 4: Retraining with combined data...")
train_model(resnet, combined_loader, epochs=5)



In [None]:
# --- 8. Evaluate on the test set ---

def evaluate(model, dataloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            predicted = torch.argmax(outputs, dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return correct / total

accuracy = evaluate(resnet, test_loader)
print(f"Step 5: Test Accuracy after pseudo labelling: {accuracy:.4f}")