In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

# Run on the GPU when CUDA reports itself as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

Using device: cuda


In [None]:
def train_val(model, criterion, optimizer, train_loader, val_loader, device, scheduler = None, use_scheduler = True):
    """Train ``model`` for one epoch, then evaluate it on ``val_loader``.

    Args:
        model: Network to optimise. It is put in train mode for the training
            pass and in eval mode for the validation pass (and is left in
            eval mode on return).
        criterion: Callable ``criterion(output, label)`` returning a scalar
            loss tensor.
        optimizer: Optimizer used to update the model after each batch.
        train_loader: Sized iterable of ``(image, label)`` training batches.
        val_loader: Sized iterable of ``(image, label)`` validation batches.
        device: Device every batch is moved to before the forward pass.
        scheduler: Optional scheduler; its ``step`` is called with the
            validation accuracy (the ``ReduceLROnPlateau``-style signature).
        use_scheduler: Step the scheduler after validation. Ignored when
            ``scheduler`` is ``None``.

    Returns:
        Tuple ``(train_loss, train_accuracy, valid_loss, valid_accuracy)``.
        Losses are the mean of the per-batch losses; accuracies are the
        fraction of correctly classified samples.
    """
    # ---- training pass -------------------------------------------------
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    for image, label in train_loader:
        image = image.to(device)
        label = label.to(device)

        optimizer.zero_grad()
        output = model(image)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        # Accumulate the batch loss exactly once; adding it twice would
        # report double the real training loss.
        train_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += (pred == label).sum().item()
        total += label.size(0)

    train_accuracy = correct/total
    train_loss = train_loss/len(train_loader)

    # ---- validation pass -----------------------------------------------
    model.eval()
    valid_loss = 0
    correct = 0
    total = 0
    # No gradients are needed for evaluation; skipping the autograd graph
    # saves memory and time.
    with torch.no_grad():
        for image, label in val_loader:
            image = image.to(device)
            label = label.to(device)

            output = model(image)
            loss = criterion(output, label)

            valid_loss += loss.item()
            pred = output.argmax(dim=1)
            correct += (pred == label).sum().item()
            total += label.size(0)

    valid_accuracy = correct/total
    valid_loss = valid_loss/len(val_loader)

    # The signature's defaults are scheduler=None with use_scheduler=True,
    # so guard against stepping a scheduler that was never supplied.
    if use_scheduler and scheduler is not None:
        scheduler.step(valid_accuracy)

    return train_loss, train_accuracy, valid_loss, valid_accuracy

def test(model, criterion, dataloader, device):
    model.eval()
    test_loss = 0
    correct = 0
    count = 0
    total = 0
    for i, data in enumerate(dataloader, 0):
        image, label = data
        image = image.to(device)
        label = label.to(device)
                
        output = model(image)
        loss = criterion(output, label)

        pred = torch.max(output.data, 1)[1]
        cur_correct = (pred == label).sum().item()
        cur_loss = loss.item()
            
        total += label.size(0)
        correct += cur_correct
        test_loss += cur_loss

    accuracy = correct/total
    test_loss = test_loss/len(dataloader)

    return test_loss, accuracy

In [None]:
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 conv + batch-norm stages.

    The input is added back to the output of the second stage through
    ``shortcut``. The shortcut is an empty ``nn.Sequential`` (identity) unless
    the stride or the channel count changes, in which case it is a strided
    1x1 convolution followed by batch-norm so both branches have equal shape.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main branch: only the first convolution carries the stride.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Skip branch: project only when the shapes would otherwise differ.
        shapes_match = stride == 1 and in_planes == out_planes
        if shapes_match:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        first_stage = F.relu(self.bn1(self.conv1(x)))
        summed = self.bn2(self.conv2(first_stage))
        summed += self.shortcut(x)
        return F.relu(summed)

In [None]:
class ZigZag_ResNet(nn.Module):
    """ResNet-style classifier whose stage widths go 64-128-256-128-64-128-256.

    A 3x3 stem convolution is followed by seven residual stages; every stage
    after the first halves the spatial resolution with a stride of 2. The
    final feature map is globally average-pooled and fed to a linear
    classifier.

    Args:
        block: Residual block class. It must expose an ``expansion`` class
            attribute and accept ``(in_planes, planes, stride)``.
        num_blocks: Sequence with (at least) seven integers, the number of
            blocks in each stage.
        num_classes: Size of the output logit vector.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super(ZigZag_ResNet, self).__init__()
        # Running channel count, updated by _make_layer as stages are built.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 128, num_blocks[3], stride=2)
        self.layer5 = self._make_layer(block, 64, num_blocks[4], stride=2)
        self.layer6 = self._make_layer(block, 128, num_blocks[5], stride=2)
        self.layer7 = self._make_layer(block, 256, num_blocks[6], stride=2)
        self.linear = nn.Linear(256*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block applies ``stride``, the rest use 1."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for images ``x``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        out = self.layer6(out)
        out = self.layer7(out)
        # Global average pooling collapses whatever spatial size is left to
        # 1x1, so the classifier always receives 256*expansion features.
        # When the map is already 1x1 (e.g. 32x32 inputs) this is an exact
        # identity, so previously working inputs produce the same logits.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        out = self.linear(out)
        return out

# Instantiate the network once just to report its trainable parameter count.
zz_model = ZigZag_ResNet(BasicBlock, [2, 2, 2, 2, 2, 1, 1])
num_params = 0
for param in zz_model.parameters():
    if param.requires_grad:
        num_params += param.numel()
print(f"Num Params: {num_params}\n")

Num Params: 4891338



In [None]:
batch_size = 64

# Training pipeline: random augmentation followed by tensor conversion and
# per-channel normalisation.
transform_train = torchvision.transforms.Compose([
  torchvision.transforms.RandomCrop(32, padding=4),
  torchvision.transforms.RandomHorizontalFlip(),
  torchvision.transforms.RandomResizedCrop(32, scale=(0.8, 1.0), ratio=(0.8, 1.2)),
  torchvision.transforms.ToTensor(),
  torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Evaluation pipeline: deterministic, same normalisation, no augmentation.
transform_test = torchvision.transforms.Compose([
                  torchvision.transforms.ToTensor(), 
                  torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])

# Two views of the same CIFAR-10 training images: one augmented (for
# training) and one using the evaluation pipeline (for validation). Splitting
# a single augmented dataset would apply random augmentation to the
# validation samples too, making the validation metrics (and the LR schedule
# driven by them) noisy.
trainset_full = torchvision.datasets.CIFAR10(root = './data', train = True, download = True, transform = transform_train)
valset_full = torchvision.datasets.CIFAR10(root = './data', train = True, download = False, transform = transform_test)

# Split the train data into train and validation sets (80% / 20%) using one
# random permutation so the two subsets are disjoint.
train_size = int(0.8 * len(trainset_full))
val_size = len(trainset_full) - train_size
split_indices = torch.randperm(len(trainset_full)).tolist()
trainset = torch.utils.data.Subset(trainset_full, split_indices[:train_size])
valset = torch.utils.data.Subset(valset_full, split_indices[train_size:])

testset = torchvision.datasets.CIFAR10(root = './data', train = False, download = True, transform = transform_test)

# Only the training loader needs shuffling; evaluation order does not matter.
train_loader = torch.utils.data.DataLoader(trainset, batch_size = batch_size, shuffle = True)
val_loader   = torch.utils.data.DataLoader(valset, batch_size = batch_size, shuffle = False)
test_loader  = torch.utils.data.DataLoader(testset, batch_size = batch_size, shuffle = False)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    
best_test_acc = 0

torch.cuda.empty_cache()
model = ZigZag_ResNet(BasicBlock, [2, 2, 2, 2, 2, 1, 1])
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Num Params: {num_params}\n")
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), 0.01, momentum = 0.8, weight_decay = 0.0005 , nesterov=True)

# mode='max' because the scheduler is stepped with validation accuracy.
# NOTE(review): `verbose` is reported as deprecated in recent PyTorch
# releases; confirm against the installed version before upgrading.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.1, patience=5, verbose = True)

# Per-epoch metric histories, appended to by the training loop.
train_losses_ = []
train_accuracies_ = []
valid_losses_ = []
valid_accuracies_ = []

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:05<00:00, 28657620.42it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Num Params: 4891338



In [None]:
epochs = 100

for epoch in range(epochs):
    print(f"\n\tEpoch: {epoch}")

    # One training pass plus one validation pass; the scheduler is stepped
    # inside train_val.
    epoch_metrics = train_val(
        model, criterion, optimizer, train_loader, val_loader, device,
        scheduler = scheduler, use_scheduler = True,
    )
    train_loss, train_accuracy, val_loss, val_accuracy = epoch_metrics

    # Record this epoch in the running histories.
    histories = (train_losses_, train_accuracies_, valid_losses_, valid_accuracies_)
    for history, value in zip(histories, epoch_metrics):
        history.append(value)

    print(f"\tTraining Loss: {round(train_loss, 4)}; Training Accuracy: {round(train_accuracy*100, 4)}%")
    print(f"\tValidation Loss: {round(val_loss, 4)}; Validation Accuracy: {round(val_accuracy*100, 4)}%")

# Single evaluation on the held-out test split once all epochs are done.
test_loss, test_accuracy = test(model, criterion, test_loader, device)
print(f"\n\tTesting Loss: {round(test_loss, 4)}; Testing Accuracy: {round(test_accuracy*100, 4)}%")

# Save the weights and gather the metrics only when this run beats the best
# test accuracy recorded so far.
if best_test_acc < test_accuracy:
    best_test_acc = test_accuracy

    torch.save(model.state_dict(), 'zigzag_resnet_tuned.pth')

    metrics_dict = dict(
        train_loss=train_losses_,
        train_accuracy=train_accuracies_,
        valid_loss=valid_losses_,
        valid_accuracy=valid_accuracies_,
        test_loss=test_loss,
        test_accuracy=test_accuracy,
    )


	Epoch: 0
	Training Loss: 3.4509; Training Accuracy: 36.6325%
	Validation Loss: 1.5205; Validation Accuracy: 44.36%

	Epoch: 1
	Training Loss: 2.7357; Training Accuracy: 50.0875%
	Validation Loss: 1.2269; Validation Accuracy: 56.2%

	Epoch: 2
	Training Loss: 2.2906; Training Accuracy: 59.1825%
	Validation Loss: 1.0767; Validation Accuracy: 62.1%

	Epoch: 3
	Training Loss: 1.9235; Training Accuracy: 66.2125%
	Validation Loss: 0.9331; Validation Accuracy: 67.33%

	Epoch: 4
	Training Loss: 1.6768; Training Accuracy: 70.7875%
	Validation Loss: 0.8811; Validation Accuracy: 69.88%

	Epoch: 5
	Training Loss: 1.5043; Training Accuracy: 73.7125%
	Validation Loss: 0.734; Validation Accuracy: 74.6%

	Epoch: 6
	Training Loss: 1.3546; Training Accuracy: 76.7075%
	Validation Loss: 0.6952; Validation Accuracy: 76.35%

	Epoch: 7
	Training Loss: 1.2548; Training Accuracy: 78.22%
	Validation Loss: 0.6533; Validation Accuracy: 77.74%

	Epoch: 8
	Training Loss: 1.1762; Training Accuracy: 79.635%
	Validat