In [None]:
import os
import random
from typing import Callable, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from PIL import Image
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import datasets, transforms
from torchvision.transforms import RandAugment

In [None]:
# ---- Reproducibility and filesystem locations ----
SEED: int = 42
DATA_DIR: str = './data'  # CIFAR-10 is downloaded/cached here
SAVE_DIR: str = './experiments'  # one sub-directory per experiment

# ---- Problem ----
NUM_CLASSES: int = 10  # CIFAR-10

# ---- Optimisation (SGD) ----
BATCH_SIZE: int = 64  # labeled batch size; unlabeled batch is BATCH_SIZE * MU
MU: int = 7  # ratio of unlabeled to labeled samples per step
LR: float = 0.03
MOMENTUM: float = 0.9
WEIGHT_DECAY: float = 5e-4

# ---- FixMatch unsupervised term ----
LAMBDA_U: float = 1.0  # weight applied to the unsupervised loss
T: float = 0.95  # minimum softmax confidence for a pseudo-label to be used

# ---- Schedule and checkpointing ----
TOTAL_EPOCHS: int = 100
SAVE_N_EPOCHS: int = 20  # also save a model snapshot every this many epochs

In [None]:
class CIFAR10SemiSupervised(Dataset):
    """Dataset wrapper that applies ``transform`` to the image of each sample.

    ``base_dataset`` must yield ``(image, label)`` pairs. The first element of
    each returned pair is whatever ``transform`` returns. With a
    ``TwoCropsTransform`` that is a list of two tensors, not a single tensor.
    """

    def __init__(self, base_dataset: Dataset, transform: Callable):
        self.base_dataset = base_dataset
        self.transform = transform

    def __len__(self) -> int:
        size = len(self.base_dataset)
        return size

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        raw_image, target = self.base_dataset[idx]
        transformed = self.transform(raw_image)
        return transformed, target


class TwoCropsTransform:
    """Turn one image into ``[weak_view, strong_view]`` for FixMatch.

    ``base_transform`` is applied once per call, and its single output feeds
    both post-transforms. Any randomness in the base transform is therefore
    shared by the two views. The weak view is produced before the strong one.
    """

    def __init__(
        self,
        base_transform: Callable,
        weak_post_transform: Callable,
        strong_post_transform: Callable,
    ):
        self.base_transform = base_transform
        self.weak_post_transform = weak_post_transform
        self.strong_post_transform = strong_post_transform

    def __call__(self, x: Image.Image) -> List[torch.Tensor]:
        shared = self.base_transform(x)
        # List elements are evaluated left to right: weak first, then strong.
        return [self.weak_post_transform(shared), self.strong_post_transform(shared)]


def _split_labeled_unlabeled(
    targets: np.ndarray, num_labeled_per_class: int, num_classes: int
) -> Tuple[List[int], List[int]]:
    """Split sample indices into a class-balanced labeled subset and the rest.

    For each class, its indices are shuffled with NumPy's global RNG and the
    first ``num_labeled_per_class`` become labeled; the remainder are
    unlabeled. Both lists are then shuffled with Python's ``random`` module,
    so the split is reproducible once both RNGs are seeded.
    """
    labeled_indices: List[int] = []
    unlabeled_indices: List[int] = []
    for class_id in range(num_classes):
        class_indices = np.where(targets == class_id)[0]
        np.random.shuffle(class_indices)
        labeled_indices.extend(class_indices[:num_labeled_per_class].tolist())
        unlabeled_indices.extend(class_indices[num_labeled_per_class:].tolist())
    random.shuffle(labeled_indices)
    random.shuffle(unlabeled_indices)
    return labeled_indices, unlabeled_indices


def get_dataloaders(num_labeled_per_class: int):
    """Build the labeled, unlabeled and test DataLoaders for FixMatch on CIFAR-10.

    Args:
        num_labeled_per_class: number of training images per class that keep
            their label. All other training images form the unlabeled set.

    Returns:
        ``(labeled_loader, unlabeled_loader, test_loader)``.
        - Labeled batches: ``(images, labels)`` with flip/crop augmentation.
        - Unlabeled batches: ``([weak_images, strong_images], labels)``.
        - Test batches: un-augmented, normalised ``(images, labels)``.

    Raises:
        ValueError: if ``num_labeled_per_class`` < 1, or if too few unlabeled
            images remain to fill one unlabeled batch. Both loaders drop
            incomplete batches, so training would otherwise run zero steps.
    """
    if num_labeled_per_class < 1:
        raise ValueError(
            f"num_labeled_per_class must be >= 1, got {num_labeled_per_class}"
        )

    cifar10_mean = (0.4914, 0.4822, 0.4465)
    cifar10_std = (0.2471, 0.2435, 0.2616)

    # Random flip + reflect-padded crop. This is the spatial part shared by
    # the labeled view and both unlabeled views.
    base_spatial_transform = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
        ]
    )

    weak_post_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(cifar10_mean, cifar10_std),
        ]
    )

    strong_post_transform = transforms.Compose(
        [
            RandAugment(num_ops=2),
            transforms.ToTensor(),
            transforms.Normalize(cifar10_mean, cifar10_std),
            # Emulates Cutout: a square patch (ratio 1:1) covering 1/5 of the
            # image area, applied with probability 0.5. It runs after
            # Normalize, so value=0 fills the patch with the per-channel mean.
            transforms.RandomErasing(
                p=0.5, scale=(0.2, 0.2), ratio=(1.0, 1.0), value=0
            ),
        ]
    )

    labeled_transform = transforms.Compose(
        [base_spatial_transform, weak_post_transform]
    )

    test_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(cifar10_mean, cifar10_std),
        ]
    )

    train_data = datasets.CIFAR10(
        DATA_DIR, train=True, download=True, transform=None
    )

    test_data = datasets.CIFAR10(
        DATA_DIR, train=False, download=True, transform=test_transform
    )

    targets = np.array(train_data.targets)
    labeled_indices, unlabeled_indices = _split_labeled_unlabeled(
        targets, num_labeled_per_class, NUM_CLASSES
    )

    print(f"Total de amostras: {len(targets)}")
    print(f"Amostras rotuladas: {len(labeled_indices)}")
    print(f"Amostras não rotuladas: {len(unlabeled_indices)}")

    unlabeled_batch_size = BATCH_SIZE * MU
    if len(unlabeled_indices) < unlabeled_batch_size:
        raise ValueError(
            f"only {len(unlabeled_indices)} unlabeled samples remain, fewer than "
            f"one unlabeled batch ({unlabeled_batch_size}); "
            f"reduce num_labeled_per_class"
        )

    labeled_dataset = CIFAR10SemiSupervised(
        Subset(train_data, labeled_indices), transform=labeled_transform
    )

    unlabeled_dataset = CIFAR10SemiSupervised(
        Subset(train_data, unlabeled_indices),
        transform=TwoCropsTransform(
            base_spatial_transform, weak_post_transform, strong_post_transform
        ),
    )

    # Cap the labeled batch at the labeled-set size. With drop_last=True, a
    # larger batch would produce zero labeled batches.
    # The validation above guarantees this is >= 1.
    labeled_batch_size = min(BATCH_SIZE, len(labeled_indices))

    loader_kwargs = dict(num_workers=2, pin_memory=True)

    labeled_loader = DataLoader(
        labeled_dataset,
        batch_size=labeled_batch_size,
        shuffle=True,
        drop_last=True,
        **loader_kwargs,
    )

    unlabeled_loader = DataLoader(
        unlabeled_dataset,
        batch_size=unlabeled_batch_size,
        shuffle=True,
        drop_last=True,
        **loader_kwargs,
    )

    test_loader = DataLoader(
        test_data,
        batch_size=100,
        shuffle=False,
        **loader_kwargs,
    )

    return labeled_loader, unlabeled_loader, test_loader

In [None]:
def get_resnet18_for_cifar(num_classes: int = NUM_CLASSES) -> nn.Module:
    """Return a randomly initialised torchvision ResNet-18 with a small-image stem.

    ``conv1`` is replaced by a 3x3, stride-1, padding-1 convolution with
    64 output channels, and ``maxpool`` is replaced by an identity. Together
    these mean the stem does not reduce the spatial size of the input.
    """
    net = models.resnet18(weights=None, num_classes=num_classes)
    small_image_stem = nn.Conv2d(
        in_channels=3,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    net.conv1 = small_image_stem
    net.maxpool = nn.Identity()
    return net


def get_resnet50_for_cifar(num_classes: int = NUM_CLASSES) -> nn.Module:
    """Return a randomly initialised torchvision ResNet-50 with a small-image stem.

    ``conv1`` is replaced by a 3x3, stride-1, padding-1 convolution with
    64 output channels, and ``maxpool`` is replaced by an identity. Together
    these mean the stem does not reduce the spatial size of the input.
    """
    net = models.resnet50(weights=None, num_classes=num_classes)
    small_image_stem = nn.Conv2d(
        in_channels=3,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
    )
    net.conv1 = small_image_stem
    net.maxpool = nn.Identity()
    return net


def fixmatch_loss(
    logits_x: torch.Tensor,
    targets_x: torch.Tensor,
    logits_u_w: torch.Tensor,
    logits_u_s: torch.Tensor,
    threshold: Optional[float] = None,
    lambda_u: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute FixMatch's supervised loss and weighted unsupervised loss.

    Args:
        logits_x: logits for the labeled batch.
        targets_x: class indices for the labeled batch.
        logits_u_w: logits for the weakly augmented unlabeled batch. They are
            used only, without gradient, to derive pseudo-labels and the
            confidence mask.
        logits_u_s: logits for the strongly augmented view of the same
            unlabeled batch. The pseudo-label loss is applied to these.
        threshold: minimum softmax confidence for a pseudo-label to count.
            Defaults to the module-level ``T``.
        lambda_u: weight of the unsupervised term. Defaults to the
            module-level ``LAMBDA_U``.

    Returns:
        ``(loss_x, lambda_u * loss_u)``. ``loss_u`` is averaged over the whole
        unlabeled batch, with masked samples contributing zero, rather than
        over the confident samples only.
    """
    if threshold is None:
        threshold = T
    if lambda_u is None:
        lambda_u = LAMBDA_U

    loss_x = nn.functional.cross_entropy(logits_x, targets_x)

    # Pseudo-labels act as fixed targets: no gradient flows into the weak branch.
    with torch.no_grad():
        probs_u_w = torch.softmax(logits_u_w, dim=1)
        max_probs, pseudo_label = torch.max(probs_u_w, dim=1)
        mask = (max_probs >= threshold).float()

    loss_u_all = nn.functional.cross_entropy(
        logits_u_s, pseudo_label, reduction='none'
    )
    loss_u = (loss_u_all * mask).mean()

    return loss_x, lambda_u * loss_u

In [None]:
def train_one_epoch(
    model: nn.Module,
    labeled_loader: DataLoader,
    unlabeled_loader: DataLoader,
    optimizer: optim.Optimizer,
    device: torch.device,
) -> Tuple[float, float, float]:
    """Run one FixMatch epoch, taking one optimizer step per unlabeled batch.

    The labeled loader is restarted whenever it is exhausted, so every step
    pairs an unlabeled batch with a labeled batch even when the labeled set
    is much smaller than the unlabeled one.

    Returns:
        ``(avg_loss, avg_loss_x, avg_loss_u)``: per-step means of the total
        loss, the supervised loss and the weighted unsupervised loss, as
        returned by ``fixmatch_loss``.

    Raises:
        ValueError: if either loader yields no batches.
    """
    num_batches_epoch = len(unlabeled_loader)
    if num_batches_epoch == 0:
        raise ValueError("unlabeled_loader yields no batches")
    if len(labeled_loader) == 0:
        raise ValueError("labeled_loader yields no batches")

    model.train()
    total_loss = 0.0
    total_loss_x = 0.0
    total_loss_u = 0.0

    labeled_iter = iter(labeled_loader)

    # Each unlabeled batch is ([weak_images, strong_images], labels).
    # The labels are ignored.
    for (u_w_batch, u_s_batch), _ in unlabeled_loader:
        try:
            x_batch, targets_x_batch = next(labeled_iter)
        except StopIteration:
            labeled_iter = iter(labeled_loader)
            x_batch, targets_x_batch = next(labeled_iter)

        x_batch = x_batch.to(device)
        targets_x_batch = targets_x_batch.to(device)
        u_w_batch = u_w_batch.to(device)
        u_s_batch = u_s_batch.to(device)

        # One forward pass over labeled + weak + strong images. The logits are
        # then sliced back into the three groups in the same order.
        inputs = torch.cat((x_batch, u_w_batch, u_s_batch))
        logits = model(inputs)

        num_labeled = x_batch.size(0)
        logits_x = logits[:num_labeled]
        logits_u_w, logits_u_s = logits[num_labeled:].chunk(2)

        loss_x, loss_u = fixmatch_loss(
            logits_x, targets_x_batch, logits_u_w, logits_u_s
        )
        loss = loss_x + loss_u

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_loss_x += loss_x.item()
        total_loss_u += loss_u.item()

    avg_loss = total_loss / num_batches_epoch
    avg_loss_x = total_loss_x / num_batches_epoch
    avg_loss_u = total_loss_u / num_batches_epoch

    return avg_loss, avg_loss_x, avg_loss_u


def validate_model(
    model: nn.Module, test_loader: DataLoader, device: torch.device
) -> Tuple[float, float]:
    """Evaluate ``model`` on every batch of ``test_loader``.

    Returns:
        ``(accuracy, loss)``: top-1 accuracy in percent and the mean
        per-sample cross-entropy over all evaluated samples.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    # Summing per batch and dividing once gives an exact per-sample mean even
    # when the last batch is smaller.
    criterion = nn.CrossEntropyLoss(reduction='sum')
    correct = 0
    total = 0
    loss_sum = 0.0

    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            loss_sum += criterion(outputs, labels).item()

    if total == 0:
        raise ValueError("test_loader yielded no samples")

    accuracy = 100 * correct / total
    return accuracy, loss_sum / total

In [None]:
def set_seed(seed: int):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with ``seed``.

    All CUDA devices are also seeded when CUDA is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def _load_resume_state(
    exp_dir: str,
    model: nn.Module,
    optimizer: optim.Optimizer,
    device: torch.device,
) -> Tuple[int, float]:
    """Restore training state from ``exp_dir`` and return ``(start_epoch, best_acc)``.

    ``last_checkpoint.pth``, written after every epoch, is preferred so that
    training continues where it stopped. ``best_model.pth`` is the fallback
    for directories that only contain a best-model checkpoint. Returns
    ``(1, 0.0)`` when neither file exists.
    """
    candidates = (
        os.path.join(exp_dir, 'last_checkpoint.pth'),
        os.path.join(exp_dir, 'best_model.pth'),
    )
    resume_path = next((p for p in candidates if os.path.exists(p)), None)
    if resume_path is None:
        print("Nenhum checkpoint encontrado, iniciando treinamento do zero.")
        return 1, 0.0

    # map_location lets a checkpoint saved on GPU be restored on a CPU-only host.
    checkpoint = torch.load(resume_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # best_model.pth carries no optimizer state; SGD momentum then restarts from zero.
    if 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    print(f"Carregado checkpoint do epoch {checkpoint['epoch']}")
    return int(checkpoint['epoch']) + 1, float(checkpoint['best_val_acc'])


def train_model_with_fixmatch(
    num_labeled_per_class: int, resnet_50: bool = False
) -> float:
    """Train one FixMatch model on CIFAR-10 and return its best test accuracy.

    Files are written to ``SAVE_DIR/<experiment_name>/``:
    - ``training_logs.csv``: one row per epoch. Appended to when resuming.
    - ``best_model.pth``: model with the highest test accuracy so far.
    - ``last_checkpoint.pth``: model and optimizer after the latest epoch,
      used to resume an interrupted run.
    - ``model_epoch_<N>.pth``: model weights every ``SAVE_N_EPOCHS`` epochs.

    Args:
        num_labeled_per_class: labeled images per class (see ``get_dataloaders``).
        resnet_50: use ResNet-50 instead of ResNet-18.

    Returns:
        Best test accuracy, in percent.

    NOTE(review): the "best" model is chosen on the test set, so the
    reported best accuracy is an optimistic estimate.
    """
    set_seed(SEED)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    experiment_name = f"FixMatch_{num_labeled_per_class}_labels_per_class"

    labeled_loader, unlabeled_loader, test_loader = get_dataloaders(
        num_labeled_per_class
    )

    if resnet_50:
        model = get_resnet50_for_cifar(NUM_CLASSES).to(device)
        experiment_name += "_ResNet50"
    else:
        model = get_resnet18_for_cifar(NUM_CLASSES).to(device)

    print(f"Iniciando: {experiment_name}")
    print(f"Usando device: {device}")

    optimizer = optim.SGD(
        model.parameters(),
        lr=LR,
        momentum=MOMENTUM,
        weight_decay=WEIGHT_DECAY,
    )

    exp_dir = os.path.join(SAVE_DIR, experiment_name)
    os.makedirs(exp_dir, exist_ok=True)
    log_path = os.path.join(exp_dir, 'training_logs.csv')
    best_path = os.path.join(exp_dir, 'best_model.pth')
    last_path = os.path.join(exp_dir, 'last_checkpoint.pth')

    # Resume first, then decide whether to start a new log. Truncating the
    # log before resuming would erase the history of the epochs already run.
    start_epoch, best_acc = _load_resume_state(exp_dir, model, optimizer, device)
    resumed = start_epoch > 1
    if not resumed or not os.path.exists(log_path):
        with open(log_path, 'w') as f:
            f.write("epoch,loss,loss_x,loss_u,test_acc,test_loss\n")

    for epoch in range(start_epoch, TOTAL_EPOCHS + 1):

        train_loss, train_loss_x, train_loss_u = train_one_epoch(
            model, labeled_loader, unlabeled_loader, optimizer, device
        )

        test_acc, test_ce_loss = validate_model(model, test_loader, device)

        print(
            f"Epoch {epoch:03d}/{TOTAL_EPOCHS:03d} | "
            f"Loss: {train_loss:.5f} (Lx: {train_loss_x:.5f}, Lu: {train_loss_u:.5f}) | "
            f"Test Acc: {test_acc:.2f}%"
        )

        with open(log_path, 'a') as f:
            f.write(
                f"{epoch},{train_loss:.5f},{train_loss_x:.5f},"
                f"{train_loss_u:.5f},{test_acc:.2f},{test_ce_loss:.5f}\n"
            )

        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'best_val_acc': best_acc,
                },
                best_path,
            )
            print(f"\tModelo Salvo com acc: {best_acc:.2f}%")

        # Saved every epoch, with optimizer state, so an interrupted run
        # resumes from its latest epoch rather than from its best one.
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_acc': best_acc,
            },
            last_path,
        )

        if epoch % SAVE_N_EPOCHS == 0:
            torch.save(
                model.state_dict(),
                os.path.join(exp_dir, f'model_epoch_{epoch}.pth'),
            )

    print(
        f"Treinamento de {experiment_name} concluído. "
        f"Melhor Acurácia: {best_acc:.2f}%"
    )
    return best_acc

In [None]:
def main():
    """Train ResNet-50 and ResNet-18 FixMatch models for each labeled-set size.

    Prints a summary of the best test accuracy per (labels/class, architecture).
    """
    set_seed(SEED)

    # Labeled images per class. The comments give the total over the 10 classes.
    labeled_per_class_cases = sorted(
        [
            1,    # case 1: 10 labels in total
            4,    # case 2: 40 labels in total
            25,   # case 3: 250 labels in total
            400,  # case 4: 4,000 labels in total
            250,  # additional case: 2,500 labels in total
        ]
    )

    results = {}
    for num_labeled in labeled_per_class_cases:
        results[f"{num_labeled} rótulos/classe (ResNet50)"] = (
            train_model_with_fixmatch(num_labeled, resnet_50=True)
        )
        results[f"{num_labeled} rótulos/classe (ResNet18)"] = (
            train_model_with_fixmatch(num_labeled, resnet_50=False)
        )

    banner = "#" * 50
    print()
    print(banner)
    print()
    print("Resultados:")
    for case, acc in results.items():
        print(f"Caso {case}: Melhor Acurácia de Teste: {acc:.2f}%")
    print()
    print(banner)

In [None]:
# Runs every experiment configured in main(): two full training runs
# (ResNet-50 and ResNet-18) per labeled-set size, then a results summary.
main()