In [None]:
batch_size = 32

# Training pipeline: random spatial augmentation, then tensor conversion and
# normalisation with the fixed per-channel mean/std constants below.
transform_train = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(32, padding=4),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomResizedCrop(32, scale=(0.8, 1.0), ratio=(0.8, 1.2)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Evaluation pipeline: same normalisation, no augmentation.
transform_test = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Two views of the SAME CIFAR-10 training images. random_split() on a single
# dataset would make the validation subset inherit transform_train, so the
# validation metrics would be measured on randomly augmented images.
train_aug_view = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
train_eval_view = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform_test)

# Split the train data into train and validation sets (80% / 20%) using one
# shared permutation so the two subsets are disjoint. The permutation is seeded
# so re-running the cell reproduces the same split.
train_size = int(0.8 * len(train_aug_view))
val_size = len(train_aug_view) - train_size
split_generator = torch.Generator().manual_seed(0)
shuffled_indices = torch.randperm(len(train_aug_view), generator=split_generator).tolist()
trainset = torch.utils.data.Subset(train_aug_view, shuffled_indices[:train_size])
valset = torch.utils.data.Subset(train_eval_view, shuffled_indices[train_size:])

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# Only the training loader needs shuffling; evaluation order does not matter.
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

best_test_acc = 0

# Fresh model, loss and optimizer for this run. ZigZag_ResNet, BasicBlock,
# ZigZagLROnPlateau and device are defined outside this cell.
torch.cuda.empty_cache()
model = ZigZag_ResNet(BasicBlock, [2, 2, 2, 2, 2, 1, 1])
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Num Params: {num_params}\n")
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), 0.01, momentum=0.8, weight_decay=0.0005, nesterov=True)

# mode='max': the scheduler is expected to be stepped with a metric where
# larger is better (presumably validation accuracy -- confirm in train_val).
scheduler = ZigZagLROnPlateau(optimizer, mode='max', up_factor=0.3, down_factor=0.5,
                              up_patience=1, down_patience=1,
                              verbose=True)

# Per-epoch history, filled in by the training loop.
train_losses_ = []
train_accuracies_ = []
valid_losses_ = []
valid_accuracies_ = []

In [None]:
import numpy as np

class ZigZagLROnPlateauRestarts(torch.optim.lr_scheduler._LRScheduler):
    """Metric-driven scheduler that raises or lowers the LR, with periodic restarts.

    Call ``step(metric)`` once per epoch. A metric strictly better than the best
    one seen since the last restart counts as a "good" epoch; anything else is a
    "bad" epoch. More than ``up_patience`` consecutive good epochs multiply the
    learning rate by ``1 + up_factor``; more than ``down_patience`` consecutive
    bad epochs multiply it by ``1 - down_factor``. Every ``restart_after`` epochs
    the best metric is forgotten and the learning rate is reset to the value
    that was active when the best metric was recorded.

    Only ``optimizer.param_groups[0]`` is adjusted.

    Args:
        optimizer: Object exposing ``param_groups`` (a list of dicts with an
            ``'lr'`` entry), e.g. a ``torch.optim`` optimizer.
        mode: ``'min'`` if a lower metric is better; any other value is treated
            as "higher is better".
        lr: Learning rate used by a restart if no improvement has been
            recorded yet.
        up_factor: Relative increase; the LR is multiplied by ``1 + up_factor``.
        down_factor: Relative decrease; the LR is multiplied by
            ``1 - down_factor``.
        up_patience: Number of consecutive good epochs tolerated before the LR
            is increased.
        down_patience: Number of consecutive bad epochs tolerated before the LR
            is reduced.
        restart_after: Restart period in epochs. A falsy value (``0`` or
            ``None``) disables restarts.
        verbose: If true, print a message on every LR change and restart.
    """

    def __init__(self, optimizer, mode='min', lr=0.01, up_factor=1.1, down_factor=0.8, up_patience=10, down_patience=10, restart_after=30, verbose=True):
        # The base-class constructor is intentionally not invoked: it performs
        # an initial step() call without a metric, which this metric-driven
        # scheduler cannot serve (ReduceLROnPlateau takes the same approach).
        self.optimizer = optimizer
        self.mode = mode
        self.up_factor = 1 + up_factor
        self.down_factor = 1 - down_factor
        self.up_patience = up_patience
        self.down_patience = down_patience
        self.num_bad_epochs = 0
        self.num_good_epochs = 0
        # float("inf") instead of np.Inf, which was removed in NumPy 2.0.
        self.best_metric = self._worst_metric()
        self.best_lr = lr
        self.restart_after = restart_after
        self.verbose = verbose
        self.num_epochs = 0
        # Keep the attribute read by the inherited get_last_lr() populated.
        self._last_lr = [group['lr'] for group in optimizer.param_groups]

    def _worst_metric(self):
        """Return the sentinel that every real metric beats in the current mode."""
        return float("inf") if self.mode == 'min' else float("-inf")

    def _is_improvement(self, metric):
        """Return True if ``metric`` is strictly better than the best seen so far."""
        if self.mode == 'min':
            return metric < self.best_metric
        return metric > self.best_metric

    def _scale_lr(self, factor, verb):
        """Multiply the LR of parameter group 0 by ``factor`` and optionally log it."""
        group = self.optimizer.param_groups[0]
        new_lr = group['lr'] * factor
        group['lr'] = new_lr
        if self.verbose:
            print(f"{verb} learning rate of group 0 to {new_lr:.4e}.")

    def step(self, metric):
        """Record one epoch's metric and adjust the learning rate if warranted.

        Args:
            metric: The monitored value for the epoch that just finished.
        """
        self.num_epochs += 1

        if self._is_improvement(metric):
            self.best_metric = metric
            # Remember the LR that produced the best metric (before any increase).
            self.best_lr = self.optimizer.param_groups[0]['lr']
            self.num_bad_epochs = 0
            self.num_good_epochs += 1
            if self.num_good_epochs > self.up_patience:
                self._scale_lr(self.up_factor, "increasing")
                self.num_good_epochs = 0
        else:
            self.num_bad_epochs += 1
            self.num_good_epochs = 0
            if self.num_bad_epochs > self.down_patience:
                self._scale_lr(self.down_factor, "reducing")
                self.num_bad_epochs = 0

        # Periodic restart: forget the best metric and return to the best LR.
        # A falsy restart_after disables restarts (and avoids a modulo by zero).
        if self.restart_after and self.num_epochs % self.restart_after == 0:
            self.best_metric = self._worst_metric()
            self.optimizer.param_groups[0]['lr'] = self.best_lr
            if self.verbose:
                print(f"restart: setting learning rate of group 0 to best learning rate value: {self.best_lr:.4e}.")

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

In [None]:
batch_size = 32

# Training pipeline: random spatial augmentation, then tensor conversion and
# normalisation with the fixed per-channel mean/std constants below.
transform_train = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(32, padding=4),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.RandomResizedCrop(32, scale=(0.8, 1.0), ratio=(0.8, 1.2)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Evaluation pipeline: same normalisation, no augmentation.
transform_test = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Two views of the SAME CIFAR-10 training images. random_split() on a single
# dataset would make the validation subset inherit transform_train, so the
# validation metrics would be measured on randomly augmented images.
train_aug_view = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
train_eval_view = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform_test)

# Split the train data into train and validation sets (80% / 20%) using one
# shared permutation so the two subsets are disjoint. The permutation is seeded
# so re-running the cell reproduces the same split.
train_size = int(0.8 * len(train_aug_view))
val_size = len(train_aug_view) - train_size
split_generator = torch.Generator().manual_seed(0)
shuffled_indices = torch.randperm(len(train_aug_view), generator=split_generator).tolist()
trainset = torch.utils.data.Subset(train_aug_view, shuffled_indices[:train_size])
valset = torch.utils.data.Subset(train_eval_view, shuffled_indices[train_size:])

# Optional: for a quick smoke test, keep only 20% of each split.
# train_size = int(0.2 * len(trainset))
# val_size = int(0.2 * len(valset))
# trainset, _ = torch.utils.data.random_split(trainset, [train_size, len(trainset) - train_size])
# valset, _ = torch.utils.data.random_split(valset, [val_size, len(valset) - val_size])

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# Only the training loader needs shuffling; evaluation order does not matter.
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

best_test_acc = 0

# Fresh model, loss and optimizer for this run. ZigZag_ResNet, BasicBlock and
# device are defined outside this cell.
torch.cuda.empty_cache()
model = ZigZag_ResNet(BasicBlock, [2, 2, 2, 2, 2, 1, 1])
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Num Params: {num_params}\n")
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), 0.01, momentum=0.8, weight_decay=0.0005, nesterov=True)

# mode='max': the scheduler is expected to be stepped with a metric where
# larger is better (presumably validation accuracy -- confirm in train_val).
# lr matches the optimizer's initial learning rate so an early restart is a no-op.
scheduler = ZigZagLROnPlateauRestarts(optimizer, mode='max', lr=0.01,
                                      up_factor=0.3, down_factor=0.5,
                                      up_patience=1, down_patience=1,
                                      restart_after=30, verbose=True)

# Per-epoch history, filled in by the training loop.
train_losses_ = []
train_accuracies_ = []
valid_losses_ = []
valid_accuracies_ = []

# Train for a fixed number of epochs, recording the four per-epoch metrics
# returned by train_val (defined elsewhere in the notebook).
epochs = 250
for epoch in range(epochs):
    print(f"\n\tEpoch: {epoch}")

    # The scheduler is handed to train_val together with use_scheduler=True;
    # presumably train_val calls scheduler.step(...) itself -- confirm there.
    train_loss, train_accuracy, val_loss, val_accuracy = train_val(model, criterion, optimizer, 
                                                                train_loader, val_loader, device,
                                                                scheduler = scheduler, use_scheduler = True)
    train_losses_.append(train_loss)
    train_accuracies_.append(train_accuracy)
    valid_losses_.append(val_loss)
    valid_accuracies_.append(val_accuracy)
    # Accuracies are multiplied by 100 for display, so they are assumed to be
    # fractions in [0, 1] -- TODO confirm against train_val.
    print(f"\tTraining Loss: {round(train_loss, 4)}; Training Accuracy: {round(train_accuracy*100, 4)}%")
    print(f"\tValidation Loss: {round(val_loss, 4)}; Validation Accuracy: {round(val_accuracy*100, 4)}%")

# Single evaluation on the held-out test set once training has finished
# (test is defined elsewhere in the notebook).
test_loss, test_accuracy = test(model, criterion, test_loader, device)
print(f"\n\tTesting Loss: {round(test_loss, 4)}; Testing Accuracy: {round(test_accuracy*100, 4)}%")

# NOTE(review): best_test_acc is reset to 0 at the top of this cell, so this
# condition holds for any positive test accuracy of the current run.
if test_accuracy > best_test_acc:
    best_test_acc = test_accuracy

    # Bundle the full training history with the final test metrics.
    # NOTE(review): nothing visible here persists metrics_dict -- confirm it is
    # saved or consumed further down.
    metrics_dict = {'train_loss': train_losses_, 'train_accuracy': train_accuracies_, 
                  'valid_loss': valid_losses_, 'valid_accuracy': valid_accuracies_,
                  'test_loss': test_loss, 'test_accuracy': test_accuracy}