In [2]:
!pip install torchattacks



In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
import numpy as np
from collections import defaultdict
import random
import time

torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Available GPUs: {torch.cuda.device_count()}")


# ==================== ResNet Building Blocks ====================
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, 
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels * self.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * self.expansion,
                         kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * self.expansion)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNetEmbedding(nn.Module):
    def __init__(self, block, num_blocks, embedding_dim=128, dropout=0.1):
        super(ResNetEmbedding, self).__init__()
        self.in_channels = 64
        
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.dropout = nn.Dropout2d(dropout)
        
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, embedding_dim)
        self.bn_emb = nn.BatchNorm1d(embedding_dim)
        
    def _make_layer(self, block, out_channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)
    
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.dropout(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        out = self.bn_emb(out)
        return F.normalize(out, p=2, dim=1)


def ResNet18Embedding(embedding_dim=128, dropout=0.1):
    return ResNetEmbedding(BasicBlock, [2, 2, 2, 2], embedding_dim, dropout)


class TripletNetwork(nn.Module):
    def __init__(self, embedding_net):
        super(TripletNetwork, self).__init__()
        self.embedding_net = embedding_net
        
    def forward(self, anchor, positive, negative):
        anchor_embedding = self.embedding_net(anchor)
        positive_embedding = self.embedding_net(positive)
        negative_embedding = self.embedding_net(negative)
        return anchor_embedding, positive_embedding, negative_embedding
    
    def get_embedding(self, x):
        return self.embedding_net(x)


def get_embedding_safe(model, x):
    """Safely get embedding from model (handles DataParallel)"""
    if isinstance(model, nn.DataParallel):
        return model.module.get_embedding(x)
    else:
        return model.get_embedding(x)


# ==================== Optimized ART-AL Loss ====================
class ARTALLoss(nn.Module):
    def __init__(self, m1=0.5, m2=2.0, lambda_min=0.2, lambda_max=0.8):
        super(ARTALLoss, self).__init__()
        self.m1 = m1
        self.m2 = m2
        self.lambda_min = lambda_min
        self.lambda_max = lambda_max
        self.lambda_k = 0.5
        self.gamma_k = 0.0
        
    def forward(self, anchor, positive, negative, anchor_adv=None, positive_adv=None):
        pos_dist = torch.sum((anchor - positive) ** 2, dim=1)
        neg_dist = torch.sum((anchor - negative) ** 2, dim=1)
        
        pos_dist = torch.clamp(pos_dist, min=0.0, max=4.0)
        neg_dist = torch.clamp(neg_dist, min=0.0, max=4.0)
        
        intra_loss = torch.clamp(pos_dist - self.m1, min=0.0)
        inter_loss = torch.clamp(self.m2 - neg_dist, min=0.0)
        
        intra_loss_mean = torch.mean(intra_loss)
        inter_loss_mean = torch.mean(inter_loss)
        
        if anchor_adv is not None and positive_adv is not None:
            adv_anchor_dist = torch.sum((anchor - anchor_adv) ** 2, dim=1)
            adv_positive_dist = torch.sum((positive - positive_adv) ** 2, dim=1)
            adv_loss = torch.mean(torch.clamp(adv_anchor_dist + adv_positive_dist, max=2.0))
            self.gamma_k = min(0.3, self.gamma_k + 0.015)
        else:
            adv_loss = torch.tensor(0.0).to(anchor.device)
            self.gamma_k = 0.0
        
        total_base = intra_loss_mean.item() + inter_loss_mean.item()
        if total_base > 1e-6:
            new_lambda = intra_loss_mean.item() / (total_base + 1e-8)
            self.lambda_k = 0.8 * self.lambda_k + 0.2 * new_lambda
            self.lambda_k = max(self.lambda_min, min(self.lambda_max, self.lambda_k))
        
        base_weight = 1.0 - self.gamma_k
        loss = base_weight * ((1 - self.lambda_k) * intra_loss_mean + 
                             self.lambda_k * inter_loss_mean) + self.gamma_k * adv_loss
        
        return loss, intra_loss_mean, inter_loss_mean, adv_loss, self.lambda_k, self.gamma_k


# ==================== FIXED PGD Attack (Works in both train and eval) ====================
def generate_pgd_attack(model, images, epsilon=8/255, alpha=2/255, num_iter=5):
    """
    PGD attack that works correctly in both training and evaluation contexts
    """
    # Store original training state
    was_training = model.training
    model.eval()  # Always use eval mode for PGD attack
    
    # Clone images and enable gradients
    images = images.detach().clone()
    images.requires_grad = True
    
    # Initialize perturbation
    delta = torch.zeros_like(images).uniform_(-epsilon, epsilon)
    delta.requires_grad = True
    
    for _ in range(num_iter):
        # Forward pass with perturbation
        perturbed = images + delta
        embeddings = get_embedding_safe(model, perturbed)
        
        # Loss: maximize embedding magnitude (simple adversarial objective)
        loss = -torch.mean(torch.sum(embeddings ** 2, dim=1))
        
        # Compute gradient
        loss.backward()
        
        # Update perturbation
        with torch.no_grad():
            grad_sign = delta.grad.sign()
            delta.data = delta.data + alpha * grad_sign
            delta.data = torch.clamp(delta.data, -epsilon, epsilon)
            delta.data = torch.clamp(images.data + delta.data, 0, 1) - images.data
        
        # Zero gradients for next iteration
        delta.grad.zero_()
    
    # Restore training state
    if was_training:
        model.train()
    
    # Return adversarial images (detached)
    return (images + delta).detach()


# ==================== Triplet Dataset ====================
class TripletDataset(Dataset):
    def __init__(self, dataset, train=True):
        self.dataset = dataset
        self.train = train
        self.labels = np.array([label for _, label in dataset])
        self.embeddings = None
        
        self.label_to_indices = defaultdict(list)
        for idx, label in enumerate(self.labels):
            self.label_to_indices[label].append(idx)
        
        self.labels_set = set(self.labels)
        
    def update_embeddings(self, embeddings):
        self.embeddings = embeddings.detach().cpu()
        
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, index):
        anchor_img, anchor_label = self.dataset[index]
        
        positive_indices = [idx for idx in self.label_to_indices[anchor_label] if idx != index]
        if len(positive_indices) == 0:
            positive_index = index
        else:
            if self.embeddings is not None and len(positive_indices) > 1:
                anchor_emb = self.embeddings[index]
                pos_embs = self.embeddings[positive_indices]
                distances = torch.sum((anchor_emb - pos_embs) ** 2, dim=1)
                k = max(1, int(len(distances) * 0.5))
                topk_indices = torch.topk(distances, k, largest=True)[1]
                selected = topk_indices[torch.randint(len(topk_indices), (1,))].item()
                positive_index = positive_indices[selected]
            else:
                positive_index = np.random.choice(positive_indices)
        
        positive_img, _ = self.dataset[positive_index]
        
        if self.embeddings is not None:
            anchor_emb = self.embeddings[index]
            negative_candidates = []
            for neg_label in (self.labels_set - {anchor_label}):
                neg_indices = self.label_to_indices[neg_label]
                sampled = np.random.choice(neg_indices, min(5, len(neg_indices)), replace=False)
                negative_candidates.extend(sampled)
            
            if len(negative_candidates) > 0:
                neg_embs = self.embeddings[negative_candidates]
                distances = torch.sum((anchor_emb - neg_embs) ** 2, dim=1)
                k = max(1, int(len(distances) * 0.3))
                topk_indices = torch.topk(distances, k, largest=False)[1]
                selected_idx = topk_indices[torch.randint(len(topk_indices), (1,))].item()
                negative_index = negative_candidates[selected_idx]
            else:
                negative_label = np.random.choice(list(self.labels_set - {anchor_label}))
                negative_index = np.random.choice(self.label_to_indices[negative_label])
        else:
            negative_label = np.random.choice(list(self.labels_set - {anchor_label}))
            negative_index = np.random.choice(self.label_to_indices[negative_label])
        
        negative_img, _ = self.dataset[negative_index]
        return anchor_img, positive_img, negative_img, anchor_label


class PredictionNetwork(nn.Module):
    def __init__(self, embedding_dim=128, num_classes=10):
        super(PredictionNetwork, self).__init__()
        self.fc1 = nn.Linear(embedding_dim, 256)
        self.bn1 = nn.BatchNorm1d(256)
        self.dropout1 = nn.Dropout(0.3)
        self.fc2 = nn.Linear(256, num_classes)
        
    def forward(self, x):
        x = F.relu(self.bn1(self.fc1(x)))
        x = self.dropout1(x)
        x = self.fc2(x)
        return x


# ==================== Training Functions ====================
def train_triplet_network(model, criterion, optimizer, train_loader, epoch, use_adv=False):
    model.train()
    total_loss = 0
    total_intra = 0
    total_inter = 0
    total_adv = 0
    
    for batch_idx, (anchor, positive, negative, labels) in enumerate(train_loader):
        anchor = anchor.to(device)
        positive = positive.to(device)
        negative = negative.to(device)
        
        optimizer.zero_grad()
        
        anchor_emb, positive_emb, negative_emb = model(anchor, positive, negative)
        
        # Adversarial samples (progressive schedule)
        if use_adv and epoch > 12:
            epsilon = 0.008 if epoch < 25 else 0.015 if epoch < 38 else 8/255
            anchor_adv = generate_pgd_attack(model, anchor, epsilon=epsilon, num_iter=5)
            positive_adv = generate_pgd_attack(model, positive, epsilon=epsilon, num_iter=5)
            
            anchor_adv_emb = get_embedding_safe(model, anchor_adv)
            positive_adv_emb = get_embedding_safe(model, positive_adv)
        else:
            anchor_adv_emb = None
            positive_adv_emb = None
        
        loss, intra_loss, inter_loss, adv_loss, lambda_k, gamma_k = criterion(
            anchor_emb, positive_emb, negative_emb, anchor_adv_emb, positive_adv_emb
        )
        
        if torch.isnan(loss) or torch.isinf(loss):
            print(f"Warning: NaN/Inf at batch {batch_idx}, skipping...")
            continue
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        
        total_loss += loss.item()
        total_intra += intra_loss.item()
        total_inter += inter_loss.item()
        total_adv += adv_loss.item()
        
        if batch_idx % 75 == 0:
            print(f'[{batch_idx:3d}/{len(train_loader)}] '
                  f'L: {loss.item():.3f} | In: {intra_loss.item():.3f} | '
                  f'Int: {inter_loss.item():.3f} | Adv: {adv_loss.item():.3f} | '
                  f'Î»: {lambda_k:.2f} | Î³: {gamma_k:.2f}')
    
    return (total_loss / len(train_loader), total_intra / len(train_loader), 
            total_inter / len(train_loader), total_adv / len(train_loader))


def train_prediction_network(triplet_model, pred_network, optimizer, train_loader, epoch, use_adv=False):
    triplet_model.eval()
    pred_network.train()
    total_loss = 0
    correct = 0
    total = 0
    
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        
        if use_adv and epoch > 8 and torch.rand(1).item() < 0.4:
            images = generate_pgd_attack(triplet_model, images, epsilon=8/255, num_iter=10)
        
        optimizer.zero_grad()
        
        with torch.no_grad():
            embeddings = get_embedding_safe(triplet_model, images)
        
        outputs = pred_network(embeddings)
        loss = criterion(outputs, labels)
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(pred_network.parameters(), max_norm=1.0)
        optimizer.step()
        
        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    
    return total_loss / len(train_loader), 100. * correct / total


def test_accuracy(triplet_model, pred_network, test_loader, use_pgd=False):
    """FIXED: PGD attack now works during evaluation"""
    triplet_model.eval()
    pred_network.eval()
    correct = 0
    total = 0
    
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        
        if use_pgd:
            # PGD attack now works correctly even in eval mode
            images = generate_pgd_attack(triplet_model, images, epsilon=8/255, num_iter=20)
        
        with torch.no_grad():
            embeddings = get_embedding_safe(triplet_model, images)
            outputs = pred_network(embeddings)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    
    return 100. * correct / total


# ==================== Main ====================
def main():
    print("\n" + "="*80)
    print("ART-AL: Adversarially-Robust Triplet Networks (FULLY FIXED)")
    print("Target: Clean 78-82%, Robust 46-52% | Time: ~10-11 hours on 2xT4")
    print("="*80 + "\n")
    
    BATCH_SIZE = 320
    EMBEDDING_DIM = 128
    NUM_EPOCHS_TRIPLET = 45
    NUM_EPOCHS_PRED = 55
    LR_TRIPLET = 0.0008
    LR_PRED = 0.08
    
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])
    
    print("Loading CIFAR-10...")
    train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                                  download=True, transform=transform_train)
    test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                                 download=True, transform=transform_test)
    
    triplet_train_dataset = TripletDataset(train_dataset, train=True)
    triplet_train_loader = DataLoader(triplet_train_dataset, batch_size=BATCH_SIZE,
                                      shuffle=True, num_workers=4, pin_memory=True)
    
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE,
                             shuffle=True, num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                            shuffle=False, num_workers=4, pin_memory=True)
    
    print("Initializing models...")
    embedding_net = ResNet18Embedding(embedding_dim=EMBEDDING_DIM, dropout=0.1).to(device)
    triplet_model = TripletNetwork(embedding_net).to(device)
    pred_network = PredictionNetwork(embedding_dim=EMBEDDING_DIM, num_classes=10).to(device)
    
    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs with DataParallel")
        triplet_model = nn.DataParallel(triplet_model)
        pred_network = nn.DataParallel(pred_network)
    
    criterion = ARTALLoss(m1=0.5, m2=2.0)
    optimizer_triplet = optim.AdamW(triplet_model.parameters(), lr=LR_TRIPLET, weight_decay=5e-4)
    scheduler_triplet = optim.lr_scheduler.CosineAnnealingLR(optimizer_triplet, T_max=NUM_EPOCHS_TRIPLET)
    
    print("\n" + "="*80)
    print("PHASE 1: Triplet Network Training with Progressive Adversarial Hardening")
    print("="*80)
    start = time.time()
    
    for epoch in range(1, NUM_EPOCHS_TRIPLET + 1):
        epoch_start = time.time()
        avg_loss, avg_intra, avg_inter, avg_adv = train_triplet_network(
            triplet_model, criterion, optimizer_triplet, triplet_train_loader, epoch, use_adv=True
        )
        scheduler_triplet.step()
        
        epoch_time = time.time() - epoch_start
        print(f'\nEpoch {epoch}/{NUM_EPOCHS_TRIPLET} ({epoch_time/60:.1f}min) | '
              f'Loss: {avg_loss:.4f} | Intra: {avg_intra:.4f} | Inter: {avg_inter:.4f} | Adv: {avg_adv:.4f}')
        
        if epoch % 15 == 0:
            print("Updating embeddings for hard mining...")
            triplet_model.eval()
            all_embeddings = []
            with torch.no_grad():
                for images, _ in train_loader:
                    images = images.to(device)
                    emb = get_embedding_safe(triplet_model, images)
                    all_embeddings.append(emb.cpu())
            all_embeddings = torch.cat(all_embeddings)
            triplet_train_dataset.update_embeddings(all_embeddings)
            torch.save(triplet_model.state_dict(), f'checkpoint_triplet_{epoch}.pth')
            print(f"âœ“ Checkpoint saved")
    
    phase1_time = time.time() - start
    print(f"\nPhase 1 completed in {phase1_time/3600:.2f} hours")
    
    print("\n" + "="*80)
    print("PHASE 2: Adversarial Partial Training (APT) of Prediction Network")
    print("="*80)
    
    for param in triplet_model.parameters():
        param.requires_grad = False
    triplet_model.eval()
    
    optimizer_pred = optim.SGD(pred_network.parameters(), lr=LR_PRED, 
                               momentum=0.9, weight_decay=5e-4, nesterov=True)
    scheduler_pred = optim.lr_scheduler.CosineAnnealingLR(optimizer_pred, T_max=NUM_EPOCHS_PRED)
    
    best_clean = 0
    best_robust = 0
    best_epoch = 0
    
    for epoch in range(1, NUM_EPOCHS_PRED + 1):
        train_loss, train_acc = train_prediction_network(
            triplet_model, pred_network, optimizer_pred, train_loader, epoch, use_adv=True
        )
        scheduler_pred.step()
        
        if epoch % 5 == 0 or epoch == NUM_EPOCHS_PRED:
            print(f"Evaluating epoch {epoch}...")
            clean_acc = test_accuracy(triplet_model, pred_network, test_loader, use_pgd=False)
            robust_acc = test_accuracy(triplet_model, pred_network, test_loader, use_pgd=True)
            
            print(f'Epoch {epoch:2d}/{NUM_EPOCHS_PRED} | Train: {train_acc:.1f}% | '
                  f'Clean: {clean_acc:.2f}% | Robust(PGD-20): {robust_acc:.2f}%')
            
            if clean_acc + robust_acc > best_clean + best_robust:
                best_clean = clean_acc
                best_robust = robust_acc
                best_epoch = epoch
                torch.save({
                    'epoch': epoch,
                    'triplet_model': triplet_model.state_dict(),
                    'pred_network': pred_network.state_dict(),
                    'clean_acc': clean_acc,
                    'robust_acc': robust_acc
                }, 'best_artal_final.pth')
                print(f'  âœ“ NEW BEST! Clean {clean_acc:.2f}% + Robust {robust_acc:.2f}%')
        else:
            print(f'Epoch {epoch:2d}/{NUM_EPOCHS_PRED} | Loss: {train_loss:.3f} | Train Acc: {train_acc:.2f}%')
    
    total_time = time.time() - start
    
    print("\n" + "="*80)
    print("FINAL RESULTS - ART-AL Framework")
    print("="*80)
    print(f"Best Clean Accuracy:  {best_clean:.2f}% (epoch {best_epoch})")
    print(f"Best Robust Accuracy: {best_robust:.2f}% (PGD-20, Îµ=8/255)")
    print(f"Combined Score:       {best_clean + best_robust:.2f}%")
    print(f"Total Training Time:  {total_time/3600:.2f} hours")
    print(f"\nModel saved to: 'best_artal_final.pth'")
    print("="*80)
    
    if best_clean >= 78 and best_robust >= 46:
        print("\nðŸŽ‰ SUCCESS! Both targets achieved!")
        print(f"   Clean:  {best_clean:.2f}% (target: â‰¥78%)")
        print(f"   Robust: {best_robust:.2f}% (target: â‰¥46%)")
    else:
        print(f"\nProgress toward targets:")
        print(f"   Clean:  {best_clean:.2f}% / 78% (gap: {max(0, 78-best_clean):.2f}%)")
        print(f"   Robust: {best_robust:.2f}% / 46% (gap: {max(0, 46-best_robust):.2f}%)")
    
    return triplet_model, pred_network, best_clean, best_robust


if __name__ == "__main__":
    main()

Using device: cuda
GPU: Tesla T4
Available GPUs: 2

ART-AL: Adversarially-Robust Triplet Networks (FULLY FIXED)
Target: Clean 78-82%, Robust 46-52% | Time: ~10-11 hours on 2xT4

Loading CIFAR-10...


100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 170M/170M [00:02<00:00, 72.5MB/s] 


Initializing models...
Using 2 GPUs with DataParallel

PHASE 1: Triplet Network Training with Progressive Adversarial Hardening
[  0/157] L: 0.696 | In: 1.384 | Int: 0.189 | Adv: 0.000 | Î»: 0.58 | Î³: 0.00
[ 75/157] L: 0.451 | In: 1.273 | Int: 0.245 | Adv: 0.000 | Î»: 0.80 | Î³: 0.00
[150/157] L: 0.421 | In: 1.253 | Int: 0.213 | Adv: 0.000 | Î»: 0.80 | Î³: 0.00

Epoch 1/45 (1.1min) | Loss: 0.4550 | Intra: 1.2503 | Inter: 0.2461 | Adv: 0.0000
[  0/157] L: 0.391 | In: 1.249 | Int: 0.176 | Adv: 0.000 | Î»: 0.80 | Î³: 0.00
[ 75/157] L: 0.416 | In: 1.266 | Int: 0.204 | Adv: 0.000 | Î»: 0.80 | Î³: 0.00
[150/157] L: 0.399 | In: 1.229 | Int: 0.191 | Adv: 0.000 | Î»: 0.80 | Î³: 0.00

Epoch 2/45 (1.2min) | Loss: 0.4129 | Intra: 1.2637 | Inter: 0.2002 | Adv: 0.0000
[  0/157] L: 0.421 | In: 1.220 | Int: 0.221 | Adv: 0.000 | Î»: 0.80 | Î³: 0.00
[ 75/157] L: 0.395 | In: 1.234 | Int: 0.185 | Adv: 0.000 | Î»: 0.80 | Î³: 0.00
[150/157] L: 0.370 | In: 1.196 | Int: 0.163 | Adv: 0.000 | Î»: 0.80 | Î³: 0.

  return F.linear(input, self.weight, self.bias)


Epoch  1/55 | Loss: 1.054 | Train Acc: 76.71%
Epoch  2/55 | Loss: 1.009 | Train Acc: 77.78%
Epoch  3/55 | Loss: 1.001 | Train Acc: 78.15%
Epoch  4/55 | Loss: 0.999 | Train Acc: 77.97%
Evaluating epoch 5...
Epoch  5/55 | Train: 78.1% | Clean: 76.69% | Robust(PGD-20): 68.66%
  âœ“ NEW BEST! Clean 76.69% + Robust 68.66%
Epoch  6/55 | Loss: 0.994 | Train Acc: 78.40%
Epoch  7/55 | Loss: 0.993 | Train Acc: 78.45%
Epoch  8/55 | Loss: 0.992 | Train Acc: 78.40%
Epoch  9/55 | Loss: 1.060 | Train Acc: 75.39%
Evaluating epoch 10...
Epoch 10/55 | Train: 75.6% | Clean: 76.51% | Robust(PGD-20): 69.10%
  âœ“ NEW BEST! Clean 76.51% + Robust 69.10%
Epoch 11/55 | Loss: 1.054 | Train Acc: 75.66%
Epoch 12/55 | Loss: 1.064 | Train Acc: 75.22%
Epoch 13/55 | Loss: 1.056 | Train Acc: 75.42%
Epoch 14/55 | Loss: 1.061 | Train Acc: 75.18%
Evaluating epoch 15...
Epoch 15/55 | Train: 75.3% | Clean: 76.66% | Robust(PGD-20): 69.37%
  âœ“ NEW BEST! Clean 76.66% + Robust 69.37%
Epoch 16/55 | Loss: 1.063 | Train Acc: 75