In [2]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

In [4]:
class ZigZag_ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ZigZag_ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 128, num_blocks[3], stride=2)
        self.layer5 = self._make_layer(block, 64, num_blocks[4], stride=2)
        self.layer6 = self._make_layer(block, 128, num_blocks[5], stride=2)
        self.layer7 = self._make_layer(block, 256, num_blocks[6], stride=2)
        self.linear = nn.Linear(256*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        out = self.layer6(out)
        out = self.layer7(out)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

zz_model = ZigZag_ResNet(BasicBlock, [2, 2, 2, 2, 2, 1, 1])
num_params = sum(p.numel() for p in zz_model.parameters() if p.requires_grad)
print(f"Num Params: {num_params}\n")

Num Params: 4891338



In [5]:
def train_val(model, criterion, optimizer, train_loader, val_loader, device, scheduler = None, use_scheduler = True):
    model.train()
    train_loss = 0
    correct = 0
    count = 0
    total = 0
    for i, data in enumerate(train_loader, 0):
        image, label = data
        image = image.to(device)
        label = label.to(device)
    
        optimizer.zero_grad()
        output = model(image)
        loss = criterion(output, label)

        train_loss += loss.item()

        pred = torch.max(output.data, 1)[1]
        cur_correct = (pred == label).sum().item()
        cur_loss = loss.item()

        loss.backward()

        optimizer.step()
        
        total += label.size(0)
        correct += cur_correct
        train_loss += cur_loss

    train_accuracy = correct/total
    train_loss = train_loss/len(train_loader)
    
    model.eval()
    valid_loss = 0
    correct = 0
    count = 0
    total = 0
    for i, data in enumerate(val_loader, 0):
        image, label = data
        image = image.to(device)
        label = label.to(device)
                
        output = model(image)
        loss = criterion(output, label)

        pred = torch.max(output.data, 1)[1]
        cur_correct = (pred == label).sum().item()
        cur_loss = loss.item()
            
        total += label.size(0)
        correct += cur_correct
        valid_loss += cur_loss

    valid_accuracy = correct/total
    valid_loss = valid_loss/len(val_loader)
    
    if use_scheduler:
        scheduler.step(valid_accuracy)

    return train_loss, train_accuracy, valid_loss, valid_accuracy

def test(model, criterion, dataloader, device):
    model.eval()
    test_loss = 0
    correct = 0
    count = 0
    total = 0
    for i, data in enumerate(dataloader, 0):
        image, label = data
        image = image.to(device)
        label = label.to(device)
                
        output = model(image)
        loss = criterion(output, label)

        pred = torch.max(output.data, 1)[1]
        cur_correct = (pred == label).sum().item()
        cur_loss = loss.item()
            
        total += label.size(0)
        correct += cur_correct
        test_loss += cur_loss

    accuracy = correct/total
    test_loss = test_loss/len(dataloader)

    return test_loss, accuracy

In [8]:
import numpy as np

class ZigZagLROnPlateauRestarts:
    def __init__(self, optimizer, mode='min', lr=0.01, up_factor=1.1, down_factor=0.8, up_patience=10, down_patience=10, restart_after=30, verbose=False):
        self.optimizer = optimizer
        self.mode = mode
        self.lr = lr
        self.up_factor = up_factor
        self.down_factor = down_factor
        self.up_patience = up_patience
        self.down_patience = down_patience
        self.restart_after = restart_after
        self.verbose = verbose
        self.best_metric = np.inf if mode == 'min' else -np.inf
        self.local_best_metric = self.best_metric
        self.best_lr = lr
        self.num_bad_epochs = 0
        self.num_good_epochs = 0
        self.local_best_epoch = 0
        self.num_epochs = 0
        
    def step(self, metric):
        if self.mode == 'min':
            is_best = metric < self.best_metric
        else:
            is_best = metric > self.best_metric
            
        # If current metric is better than best metric, update best metric and learning rate
        if is_best:
            self.best_metric = metric
            self.best_lr = self.lr
            self.local_best_metric = metric
            self.local_best_lr = self.lr
            self.local_best_epoch = self.num_epochs
            self.num_good_epochs = 0
        else:
            self.num_good_epochs += 1
            
        # If current metric is worse than best metric, increment bad epochs counter
        if not is_best:
            self.num_bad_epochs += 1
        
        # If bad epochs exceed up_patience, increase learning rate
        if self.num_bad_epochs > self.up_patience:
            self.lr *= self.up_factor
            self.num_bad_epochs = 0
            self.num_good_epochs = 0
        
        # If good epochs exceed down_patience, decrease learning rate
        if self.num_good_epochs > self.down_patience:
            self.lr *= self.down_factor
            self.num_bad_epochs = 0
            self.num_good_epochs = 0
        
        # Restart learning rate after a certain number of epochs
        if self.num_epochs % self.restart_after == 0:
            self.optimizer.param_groups[0]['lr'] = self.best_lr
            if self.verbose:
                print(f"Restarting learning rate of group 0 to {self.best_lr:.4e}.")
            self.local_best_lr = self.best_lr
            self.num_bad_epochs = 0
            self.num_good_epochs = 0
        
        self.num_epochs += 1
        
        # If local best epoch is more than 10 epochs ago, update learning rate to local best learning rate
        if self.num_epochs - self.local_best_epoch > 10:
            self.optimizer.param_groups[0]['lr'] = self.local_best_lr
        
        return self.optimizer.param_groups[0]['lr']


In [9]:
batch_size = 32
    
transform_train = torchvision.transforms.Compose([
  torchvision.transforms.RandomCrop(32, padding=4),
  torchvision.transforms.RandomHorizontalFlip(),
  torchvision.transforms.RandomResizedCrop(32, scale=(0.8, 1.0), ratio=(0.8, 1.2)),
  torchvision.transforms.ToTensor(),
  torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

transform_test = torchvision.transforms.Compose([
                  torchvision.transforms.ToTensor(), 
                  torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

trainset = torchvision.datasets.CIFAR10(root = './data', train = True, download = True, transform = transform_train)

# Split the train data into train and validation sets
train_size = int(0.8 * len(trainset))
val_size = len(trainset) - train_size
trainset, valset = torch.utils.data.random_split(trainset, [train_size, val_size])

# train_size = int(0.2 * len(trainset))
# val_size = int(0.2 * len(valset))
# trainset, _ = torch.utils.data.random_split(trainset, [train_size, len(trainset) - train_size])
# valset, _ = torch.utils.data.random_split(valset, [val_size, len(valset) - val_size])

testset = torchvision.datasets.CIFAR10(root = './data', train = False, download = True, transform = transform_test)

train_loader = torch.utils.data.DataLoader(trainset, batch_size = batch_size, shuffle = True)
val_loader   = torch.utils.data.DataLoader(valset, batch_size = batch_size, shuffle = True)
test_loader  = torch.utils.data.DataLoader(testset, batch_size = batch_size, shuffle = False)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    
best_test_acc = 0

torch.cuda.empty_cache()
model = ZigZag_ResNet(BasicBlock, [2, 2, 2, 2, 2, 1, 1])
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Num Params: {num_params}\n")
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), 0.01, momentum = 0.8, weight_decay = 0.0005 , nesterov=True)

scheduler = ZigZagLROnPlateauRestarts(optimizer, mode='max', lr=0.01,
                                      up_factor=0.3, down_factor=0.5, 
                                      up_patience=1, down_patience=1, 
                                      restart_after=30, verbose = True)
train_losses_ = []
train_accuracies_ = []
valid_losses_ = []
valid_accuracies_ = []

epochs = 250
for epoch in range(epochs):
    print(f"\n\tEpoch: {epoch}")

    train_loss, train_accuracy, val_loss, val_accuracy = train_val(model, criterion, optimizer, 
                                                                train_loader, val_loader, device,
                                                                scheduler = scheduler, use_scheduler = True)
    train_losses_.append(train_loss)
    train_accuracies_.append(train_accuracy)
    valid_losses_.append(val_loss)
    valid_accuracies_.append(val_accuracy)
    print(f"\tTraining Loss: {round(train_loss, 4)}; Training Accuracy: {round(train_accuracy*100, 4)}%")
    print(f"\tValidation Loss: {round(val_loss, 4)}; Validation Accuracy: {round(val_accuracy*100, 4)}%")

test_loss, test_accuracy = test(model, criterion, test_loader, device)
print(f"\n\tTesting Loss: {round(test_loss, 4)}; Testing Accuracy: {round(test_accuracy*100, 4)}%")

if test_accuracy > best_test_acc:
    best_test_acc = test_accuracy

    metrics_dict = {'train_loss': train_losses_, 'train_accuracy': train_accuracies_, 
                  'valid_loss': valid_losses_, 'valid_accuracy': valid_accuracies_,
                  'test_loss': test_loss, 'test_accuracy': test_accuracy}

Files already downloaded and verified
Files already downloaded and verified
Num Params: 4891338


	Epoch: 0
Restarting learning rate of group 0 to 1.0000e-02.
	Training Loss: 3.5444; Training Accuracy: 34.3275%
	Validation Loss: 1.5409; Validation Accuracy: 42.1%

	Epoch: 1
	Training Loss: 2.8075; Training Accuracy: 48.9475%
	Validation Loss: 1.338; Validation Accuracy: 52.09%

	Epoch: 2
	Training Loss: 2.3219; Training Accuracy: 58.825%
	Validation Loss: 1.0754; Validation Accuracy: 62.18%

	Epoch: 3
	Training Loss: 1.9578; Training Accuracy: 65.4675%
	Validation Loss: 0.8788; Validation Accuracy: 69.3%

	Epoch: 4
	Training Loss: 1.7204; Training Accuracy: 70.085%
	Validation Loss: 0.7922; Validation Accuracy: 72.37%

	Epoch: 5
	Training Loss: 1.5575; Training Accuracy: 73.07%
	Validation Loss: 0.7392; Validation Accuracy: 74.5%

	Epoch: 6
	Training Loss: 1.4351; Training Accuracy: 75.34%
	Validation Loss: 0.7501; Validation Accuracy: 73.89%

	Epoch: 7
	Training Loss: 1.3287; Training

KeyboardInterrupt: 