In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets

from resnet import resnet

from feather import Pruner

import os
import torch
import torch.optim as optim
import torch.nn as nn
from datetime import datetime

from tqdm import tqdm

# Use the CUDA device when PyTorch reports one as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Channel-wise normalisation constants shared by the training and evaluation pipelines.
channel_mean = [0.4914, 0.4822, 0.4465]
channel_std = [0.2023, 0.1994, 0.201]

# --- Training split: random crop + horizontal flip augmentation, then normalise. ---
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(size=32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
)
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_train,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)

# --- Held-out split: no augmentation, only tensor conversion and normalisation. ---
transform_val = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
)
test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_val,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=100,
    shuffle=False,
    num_workers=2,
)

In [None]:
# Run configuration. Every entry is written to hyperparameters.txt and all but
# 'sparsity_type' are encoded in the name of the output folder.
hyperparameters = dict(
    sparsity_type="feather",  # nm / entropy / feather / spartan / ses / base (no sparsity)
    epochs=100,
    lr=0.1,
    momentum=0.9,
    weight_decay=5e-4,
    # NOTE: despite the key name this is len(train_loader), i.e. the number of
    # batches per epoch (not the samples per batch); it is the value handed to
    # the feather Pruner as `nbatches`.
    batch_size=len(train_loader),
)

In [None]:
# Resolve where this run's artefacts go: <base>/<sparsity_type>/<hyperparameter tag>/.
sparsity_type = hyperparameters['sparsity_type']

# The base directory can be overridden with the ECE661_OUTPUT_DIR environment
# variable so the notebook is not tied to one user's home directory. When the
# variable is unset the original location is used unchanged.
output_base_path = os.environ.get("ECE661_OUTPUT_DIR", "/home/sg666/Class/ECE661/outputs")
sparsity_folder_path = os.path.join(output_base_path, sparsity_type)

# Folder name that encodes the training hyperparameters of this run.
hyperparameter_str = (
    f"epochs_{hyperparameters['epochs']}"
    f"_lr_{hyperparameters['lr']}"
    f"_momentum_{hyperparameters['momentum']}"
    f"_wd_{hyperparameters['weight_decay']}"
    f"_batch_{hyperparameters['batch_size']}"
)
output_folder = os.path.join(sparsity_folder_path, hyperparameter_str)

os.makedirs(output_folder, exist_ok=True)

# Record the full configuration next to the other outputs, one "key: value" per line.
hyperparameter_file = os.path.join(output_folder, 'hyperparameters.txt')
with open(hyperparameter_file, 'w') as f:
    for key, value in hyperparameters.items():
        f.write(f"{key}: {value}\n")

def train(model, train_loader, criterion, optimizer, epoch, log_file):
    """Run one optimisation pass over ``train_loader`` and log the epoch summary.

    Relies on the notebook-level globals ``device`` and ``sparsity_type`` and,
    when ``sparsity_type`` is ``'feather'``, on the global ``pruner``.

    Args:
        model: network to optimise; it is switched to training mode.
        train_loader: iterable of ``(inputs, targets)`` batches supporting ``len()``.
        criterion: loss callable invoked as ``criterion(outputs, targets)``.
        optimizer: optimiser whose ``zero_grad``/``step`` run once per batch.
        epoch: zero-based epoch index; displayed and logged as ``epoch + 1``.
        log_file: writable text stream that receives one summary line.

    Returns:
        ``(avg_loss, accuracy)``: the mean of the per-batch loss values and the
        percentage of samples whose top-scoring class equals the target.
    """
    model.train()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}", ncols=100)

    for step, (inputs, targets) in enumerate(progress, start=1):
        # The feather pruner is notified before every batch is processed.
        if sparsity_type == 'feather':
            pruner.update_thresh()

        inputs = inputs.to(device)
        targets = targets.to(device)

        # Standard forward / backward / parameter update.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Accumulate running statistics for the progress bar and the epoch summary.
        loss_sum += loss.item()
        predicted = outputs.max(1)[1]
        n_correct += (predicted == targets).sum().item()
        n_seen += targets.size(0)

        progress.set_postfix(loss=loss_sum / step, accuracy=100.0 * n_correct / n_seen)

    avg_loss = loss_sum / len(train_loader)
    accuracy = 100.0 * n_correct / n_seen
    log_file.write(f'Epoch [{epoch+1}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%\n')

    # Second pruner notification, issued once after the whole pass has finished.
    if sparsity_type == 'feather':
        pruner.update_thresh(end_of_batch=True)

    return avg_loss, accuracy


def test(model, test_loader, criterion, log_file):
    """Evaluate ``model`` on ``test_loader`` without gradients and log the result.

    Relies on the notebook-level global ``device``.

    Args:
        model: network to evaluate; it is switched to evaluation mode.
        test_loader: iterable of ``(inputs, targets)`` batches supporting ``len()``.
        criterion: loss callable invoked as ``criterion(outputs, targets)``.
        log_file: writable text stream that receives one summary line.

    Returns:
        ``(avg_test_loss, accuracy)``: the mean of the per-batch loss values and
        the percentage of samples whose top-scoring class equals the target.
    """
    model.eval()
    correct = 0
    total = 0
    test_loss = 0.0

    pbar = tqdm(test_loader, desc="Testing", ncols=100)

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(pbar):
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()

            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()
            total += targets.size(0)

            # test_loss accumulates one criterion value per batch, so the running
            # average divides by the number of batches seen so far. This agrees
            # with the final avg_test_loss below; dividing by a sample count
            # would under-report the displayed loss.
            pbar.set_postfix(loss=test_loss / (batch_idx + 1), accuracy=100.0 * correct / total)

    avg_test_loss = test_loss / len(test_loader)
    accuracy = 100.0 * correct / total
    log_file.write(f'Test Loss: {avg_test_loss:.4f}, Accuracy: {accuracy:.2f}%\n')

    return avg_test_loss, accuracy


# Build the network and move it to the selected device.
resnet20_model = resnet()
resnet20_model.to(device)

# TODO: decide how robust our sparsity implementations will be. I suggest following what the feather paper did and using a pruner class.
# Only 'feather' and 'base' bind `pruner`; any other sparsity_type leaves it undefined.
if sparsity_type == 'feather':
    # Hand the model to the feather pruner together with the device, the final
    # rate, the number of batches per epoch and the number of epochs.
    pruner = Pruner(
        resnet20_model,
        device,
        final_rate=0.95,
        nbatches=hyperparameters['batch_size'],
        epochs=hyperparameters['epochs'],
    )
elif sparsity_type == 'base':
    # No pruning: `pruner` is simply an alias of the model itself.
    pruner = resnet20_model

# Loss and optimiser are created after the pruner has been given the model.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    resnet20_model.parameters(),
    lr=hyperparameters['lr'],
    momentum=hyperparameters['momentum'],
    weight_decay=hyperparameters['weight_decay'],
)

# TODO: implement a learning rate scheduler

log_file_path = os.path.join(output_folder, 'training_log.txt')

# Bind the checkpoint path before the loop so the name exists even if no epoch
# ever improves on the initial best accuracy; later cells reuse this name.
model_checkpoint_path = os.path.join(output_folder, "model_best.pth")

with open(log_file_path, 'w') as log_file:
    log_file.write(f"Training started at {datetime.now()}\n")

    # I don't think we should be delineating between nm and other sparsity types
    # in the training loop, so one loop serves every sparsity type.
    best_accuracy = 0.0

    for epoch in range(hyperparameters['epochs']):
        train_loss, train_accuracy = train(resnet20_model, train_loader, criterion, optimizer, epoch, log_file)
        test_loss, test_accuracy = test(resnet20_model, test_loader, criterion, log_file)

        # Keep only the weights of the epoch with the highest test accuracy so far.
        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            torch.save(resnet20_model.state_dict(), model_checkpoint_path)
            print(f"Saved best model at epoch {epoch+1} with accuracy: {best_accuracy:.2f}%")

    log_file.write(f"Training completed at {datetime.now()}\n")
    log_file.write(f"Best model accuracy: {best_accuracy:.2f}%\n")


In [None]:
# Post-training summary for the selected sparsity type.
if sparsity_type == "base":
    print("Base model training completed.")
elif sparsity_type == "feather":
    print("Feather sparsity training completed.")
    pr = pruner.print_sparsity()
    print(f"prune rate : {pr}" )
    pruner.desparsify()
    # Rebuild the path from `output_folder` so this cell does not depend on
    # whether the training loop happened to bind `model_checkpoint_path`.
    model_checkpoint_path = os.path.join(output_folder, "model_best.pth")
    # NOTE(review): this writes the model state as it is after the final epoch and
    # `pruner.desparsify()` to the same file the training loop uses for the
    # best-accuracy epoch, replacing that checkpoint — confirm this is intended.
    torch.save(resnet20_model.state_dict(), model_checkpoint_path)
