<a href="https://colab.research.google.com/github/uicids560/Efficient-Deep-Training/blob/main/CNN_CIFAR10.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone https://github.com/decile-team/cords.git
%cd cords/
%ls
!pip install dotmap
!pip install apricot-select
!pip install ray[default]
!pip install ray[tune]

import time
import numpy as np
import logging
import os
import os.path as osp
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from cords.utils.data.datasets.SL import gen_dataset
from cords.utils.config_utils import load_config_data
from cords.utils.data.data_utils import WeightedSubset
from torch.utils.data import Subset
from ray import tune

fatal: destination path 'cords' already exists and is not an empty directory.
/content/cords/cords
__init__.py  [0m[01;34m__pycache__[0m/  [01;34mselectionstrategies[0m/  test.py  [01;34mutils[0m/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


CIFAR 10 Dataset

In [None]:
trainset, validset, testset, num_cls = gen_dataset('data/', 'cifar10', None, isnumpy=False)

trn_batch_size = 20
val_batch_size = 20
tst_batch_size = 1000

# Creating the Data Loaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size=trn_batch_size,
                                          shuffle=False, pin_memory=True)

valloader = torch.utils.data.DataLoader(validset, batch_size=val_batch_size,
                                        shuffle=False, pin_memory=True)

testloader = torch.utils.data.DataLoader(testset, batch_size=tst_batch_size,
                                          shuffle=False, pin_memory=True)
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting data/cifar-10-python.tar.gz to data/
Files already downloaded and verified


CNN Model, Loss Function & Optimizer Definition

In [None]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = Net()

#from cords.utils.models import CNN
device = 'cuda' 
model = model.to(device)

#Loss Functions
criterion = nn.CrossEntropyLoss()

#Optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
#T_max is the maximum number of scheduler steps. Here we are using the number of epochs as the maximum number of scheduler steps.

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                       T_max=300)

Definitions

In [None]:
def generate_cumulative_timing(mod_timing):
    tmp = 0
    mod_cum_timing = np.zeros(len(mod_timing))
    for i in range(len(mod_timing)):
        tmp += mod_timing[i]
        mod_cum_timing[i] = tmp
    return mod_cum_timing / 3600

def __get_logger(results_dir):
  os.makedirs(results_dir, exist_ok=True)
  # setup logger
  plain_formatter = logging.Formatter("[%(asctime)s] %(name)s %(levelname)s: %(message)s",
                                      datefmt="%m/%d %H:%M:%S")
  logger = logging.getLogger(__name__)
  logger.setLevel(logging.INFO)
  s_handler = logging.StreamHandler(stream=sys.stdout)
  s_handler.setFormatter(plain_formatter)
  s_handler.setLevel(logging.INFO)
  logger.addHandler(s_handler)
  f_handler = logging.FileHandler(os.path.join(results_dir, "results.log"))
  f_handler.setFormatter(plain_formatter)
  f_handler.setLevel(logging.DEBUG)
  logger.addHandler(f_handler)
  logger.propagate = False
  return logger
#Results logging directory
results_dir = osp.abspath(osp.expanduser('results'))
logger = __get_logger(results_dir)

#Evaluation Metrics
trn_losses = list()
val_losses = list()
tst_losses = list()
timing = list()
trn_acc = list()
val_acc = list()  
tst_acc = list()  
#Checkpointing Metrics
save_every= 10
is_save = True

Training Loop

In [None]:
#Arguments
num_epochs = 50
print_every = 1
print_args = ["val_loss", "val_acc", "tst_loss", "tst_acc", "time"]
timing= list()
num_samples = 20
#Training Loop
for epoch in range(num_epochs):
    model.train()
    start_time = time.time()
    for _, (inputs, targets) in enumerate(trainloader):
        inputs = inputs.to(device)
        targets = targets.to(device, non_blocking=True)
        optimizer.zero_grad()
        outputs = model(inputs)
        losses = criterion(outputs, targets)
        optimizer.step()
        _, predicted = outputs.max(1)
    epoch_time = time.time() - start_time
    scheduler.step()
    timing.append(epoch_time)

In [None]:
#Evaluation Metrics
trn_losses = list()
val_losses = list()
tst_losses = list()
trn_acc = list()
val_acc = list()  
tst_acc = list()  
if (epoch + 1) % print_every == 0:
        trn_loss = 0
        trn_correct = 0
        trn_total = 0
        val_loss = 0
        val_correct = 0
        val_total = 0
        tst_correct = 0
        tst_total = 0
        tst_loss = 0
        model.eval()

        if ("trn_loss" in print_args) or ("trn_acc" in print_args):
            with torch.no_grad():
                for _, (inputs, targets) in enumerate(trainloader):
                    inputs, targets = inputs.to(device), \
                                      targets.to(device, non_blocking=True)
                    outputs = model(inputs)
                    loss = criterion(outputs, targets)
                    trn_loss += loss.item()
                    if "trn_acc" in print_args:
                        _, predicted = outputs.max(1)
                        trn_total += targets.size(0)
                        trn_correct += predicted.eq(targets).sum().item()
                trn_losses.append(trn_loss)

            if "trn_acc" in print_args:
                trn_acc.append(trn_correct / trn_total)

        if ("val_loss" in print_args) or ("val_acc" in print_args):
            with torch.no_grad():
                for _, (inputs, targets) in enumerate(valloader):
                    inputs, targets = inputs.to(device), \
                                      targets.to(device, non_blocking=True)
                    outputs = model(inputs)
                    loss = criterion(outputs, targets)
                    val_loss += loss.item()
                    if "val_acc" in print_args:
                        _, predicted = outputs.max(1)
                        val_total += targets.size(0)
                        val_correct += predicted.eq(targets).sum().item()
                val_losses.append(val_loss)

            if "val_acc" in print_args:
                val_acc.append(val_correct / val_total)

        if ("tst_loss" in print_args) or ("tst_acc" in print_args):
            with torch.no_grad():
                for _, (inputs, targets) in enumerate(testloader):
                    inputs, targets = inputs.to(device), \
                                      targets.to(device, non_blocking=True)
                    outputs = model(inputs)
                    loss = criterion(outputs, targets)
                    tst_loss += loss.item()
                    if "tst_acc" in print_args:
                        _, predicted = outputs.max(1)
                        tst_total += targets.size(0)
                        tst_correct += predicted.eq(targets).sum().item()
                tst_losses.append(tst_loss)

            if "tst_acc" in print_args:
                tst_acc.append(tst_correct / tst_total)

        print_str = "Epoch: " + str(epoch + 1)


Results

In [None]:
for arg in print_args:
  if arg == "val_loss":
    print_str += " , " + "Validation Loss: " + str(val_losses[-1])
  if arg == "val_acc":
    print_str += " , " + "Validation Accuracy: " + str(val_acc[-1])
  if arg == "tst_loss":
    print_str += " , " + "Test Loss: " + str(tst_losses[-1])
  if arg == "tst_acc":
    print_str += " , " + "Test Accuracy: " + str(tst_acc[-1])
  if arg == "trn_loss":
    print_str += " , " + "Training Loss: " + str(trn_losses[-1])
  if arg == "trn_acc":
    print_str += " , " + "Training Accuracy: " + str(trn_acc[-1])
  if arg == "time":
    print_str += " , " + "Timing: " + str(timing[-1])

logger.info(print_str)

if "val_acc" in print_args:
    val_str = "Validation Accuracy, "
    for val in val_acc:
        val_str = val_str + " , " + str(val)
    logger.info(val_str)

if "tst_acc" in print_args:
    tst_str = "Test Accuracy, "
    for tst in tst_acc:
        tst_str = tst_str + " , " + str(tst)
    logger.info(tst_str)

if "time" in print_args:
    time_str = "Time, "
    for t in timing:
        time_str = time_str + " , " + str(t)
    logger.info(timing)


[10/13 21:15:34] __main__ INFO: Epoch: 50 , Validation Loss: 576.9893550872803 , Validation Accuracy: 0.104 , Test Loss: 23.058031797409058 , Test Accuracy: 0.1099 , Timing: 15.831066370010376 , Validation Loss: 576.9893550872803 , Validation Accuracy: 0.104 , Test Loss: 23.058031797409058 , Test Accuracy: 0.1099 , Timing: 15.831066370010376
[10/13 21:15:34] __main__ INFO: Epoch: 50 , Validation Loss: 576.9893550872803 , Validation Accuracy: 0.104 , Test Loss: 23.058031797409058 , Test Accuracy: 0.1099 , Timing: 15.831066370010376 , Validation Loss: 576.9893550872803 , Validation Accuracy: 0.104 , Test Loss: 23.058031797409058 , Test Accuracy: 0.1099 , Timing: 15.831066370010376
[10/13 21:15:34] __main__ INFO: Validation Accuracy,  , 0.104
[10/13 21:15:34] __main__ INFO: Validation Accuracy,  , 0.104
[10/13 21:15:34] __main__ INFO: Test Accuracy,  , 0.1099
[10/13 21:15:34] __main__ INFO: Test Accuracy,  , 0.1099
[10/13 21:15:34] __main__ INFO: [16.74179196357727, 16.462966442108154, 16