In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.hub
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

class SEBlock(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SEBlock, self).__init__()
        self.fc1 = nn.Linear(channel, channel // reduction, bias=True)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(channel // reduction, channel, bias=True)
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, x):
        batch, channels, _, _ = x.size()
        y = F.adaptive_avg_pool2d(x, 1).view(batch, channels)
        y = self.fc1(y)
        y = self.relu(y)
        y = self.fc2(y)
        y = self.sigmoid(y).view(batch, channels, 1, 1)
        return x * y.expand_as(x)

class UpsampleModule(nn.Module):
    def __init__(self, in_channels=3, out_channels=3):
        super(UpsampleModule, self).__init__()
        self.upsample = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),  # 32 -> 64
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            SEBlock(64),
            
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),  # 64 -> 128
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            SEBlock(128),
            
            nn.Upsample(scale_factor=1.75, mode='bilinear', align_corners=True),  # 128 * 1.75 ≈ 224
            nn.Conv2d(128, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            SEBlock(out_channels)
        )
        
    def forward(self, x):
        return self.upsample(x)

In [2]:
class EfficientNetV2_FeatureExtraction(nn.Module):
    def __init__(self, weight_path_small='Weights/efficientnet_v2_s_cifar10.pth', num_classes=10):
        super(EfficientNetV2_FeatureExtraction, self).__init__()
        
        #upsampling module: 32x32 -> 224x224
        self.upsampler = UpsampleModule(in_channels=3, out_channels=3)
        
        model_name_small = 'efficientnet_v2_s'
        self.efficientnet = torch.hub.load('hankyul2/EfficientNetV2-pytorch', model_name_small, nclass=num_classes, skip_validation=True)
        
        self.efficientnet.load_state_dict(torch.load(weight_path_small, map_location=torch.device('cpu')))
    def forward(self, x):
        x = self.upsampler(x)
        x = self.efficientnet(x)
        return x


In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
model = EfficientNetV2_FeatureExtraction(weight_path_small='Weights/efficientnet_v2_s_cifar10.pth', num_classes=10)
model = model.to(device)

Using cache found in /home/ubuntu/.cache/torch/hub/hankyul2_EfficientNetV2-pytorch_main
  self.efficientnet.load_state_dict(torch.load(weight_path_small, map_location=torch.device('cpu')))


In [4]:
from torch.utils.data import DataLoader, random_split

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

transform = transforms.ToTensor()

full_train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_size = int(0.8 * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size
train_dataset, val_dataset = random_split(full_train_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42))

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=100, shuffle=False, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False, num_workers=4, pin_memory=True)


Files already downloaded and verified
Files already downloaded and verified
Training samples: 40000
Validation samples: 10000
Test samples: 10000


In [5]:
import torch.optim as optim

optimizer = optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()), 
    lr=0.001, 
    weight_decay=0.005
)

scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer, 
    max_lr=0.001, 
    total_steps=20 * len(train_loader), 
    anneal_strategy='linear',
    pct_start=0.3, 
    div_factor=25, 
    final_div_factor=1e4  
)

In [6]:
import os
from copy import deepcopy

patience = 5
best_val_loss = float('inf')
epochs_no_improve = 0
early_stop = False
best_model_state = deepcopy(model.state_dict())

num_epochs = 20

for epoch in range(num_epochs):
    if early_stop:
        print("Early stopping triggered.")
        break

    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    print(f"\nEpoch {epoch + 1}/{num_epochs}")
    print("-" * 30)

    for batch_idx, (inputs, targets) in enumerate(train_loader, 1):
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()
        scheduler.step()

        running_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        print(f"Batch [{batch_idx}/{len(train_loader)}] - Loss: {loss.item():.4f}")

    train_loss = running_loss / total
    train_acc = 100. * correct / total

    # Validation Phase
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            val_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            val_total += targets.size(0)
            val_correct += predicted.eq(targets).sum().item()

    val_loss /= val_total
    val_acc = 100. * val_correct / val_total

    print(f"Epoch [{epoch + 1}/{num_epochs}] "
          f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% "
          f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        epochs_no_improve = 0
        best_model_state = deepcopy(model.state_dict())
        print(f"Validation loss improved to {val_loss:.4f}.")
    else:
        epochs_no_improve += 1
        print(f"No improvement in validation loss for {epochs_no_improve} epoch(s).")
        if epochs_no_improve >= patience:
            print("Early stopping triggered.")
            early_stop = True

model.load_state_dict(best_model_state)



Epoch 1/20
------------------------------
Batch [1/157] - Loss: 1.0709
Batch [2/157] - Loss: 1.0543
Batch [3/157] - Loss: 1.0435
Batch [4/157] - Loss: 1.0774
Batch [5/157] - Loss: 1.0745
Batch [6/157] - Loss: 1.0018
Batch [7/157] - Loss: 1.0699
Batch [8/157] - Loss: 0.9521
Batch [9/157] - Loss: 1.0037
Batch [10/157] - Loss: 0.8448
Batch [11/157] - Loss: 0.9285
Batch [12/157] - Loss: 0.8719
Batch [13/157] - Loss: 0.9145
Batch [14/157] - Loss: 0.9566
Batch [15/157] - Loss: 0.8715
Batch [16/157] - Loss: 0.9654
Batch [17/157] - Loss: 0.8398
Batch [18/157] - Loss: 0.8433
Batch [19/157] - Loss: 0.7801
Batch [20/157] - Loss: 0.8637
Batch [21/157] - Loss: 0.7886
Batch [22/157] - Loss: 0.7989
Batch [23/157] - Loss: 0.7482
Batch [24/157] - Loss: 0.7111
Batch [25/157] - Loss: 0.7317
Batch [26/157] - Loss: 0.7547
Batch [27/157] - Loss: 0.8056
Batch [28/157] - Loss: 0.8213
Batch [29/157] - Loss: 0.6990
Batch [30/157] - Loss: 0.7720
Batch [31/157] - Loss: 0.7290
Batch [32/157] - Loss: 0.7263
Batch 

<All keys matched successfully>

In [7]:
model.load_state_dict(best_model_state)

model.eval()
test_loss = 0.0
test_correct = 0
test_total = 0
with torch.no_grad():
    for inputs, targets in test_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        
        test_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        test_total += targets.size(0)
        test_correct += predicted.eq(targets).sum().item()

test_loss /= test_total
test_acc = 100. * test_correct / test_total

print(f"Test Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%")

Test Loss: 0.5809 | Test Acc: 96.65%
