In [1]:
import glob
import sys
import os
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import scipy as spy
from torchvision import datasets
from torchvision import transforms
import matplotlib.pyplot as plt
from torch.utils.data.sampler import SubsetRandomSampler
import ssl
import pickle, json
import src.main as lc
from src.models.AlexNet import AlexNet
import src.compression.deltaCompress as lc_compress
from src.models.AlexNet_LowRank import getBase, AlexNet_LowRank, load_sd_decomp
from src.utils.utils import evaluate_accuracy, lazy_restore, evaluate_compression

In [2]:
HDFP = "/volumes/Ultra Touch" # Load HHD


# Set up training data:

def data_loader():
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])

    trainset = datasets.MNIST(root='./data', train=True,
                                          download=True, transform=transform)
    # Reintroduce the 2000 datapoints model has not seen before.
    trainset.data = trainset.data.clone()[-2000:-1000]
    trainset.targets = trainset.targets.clone()[-2000:-1000]
    trainloader = torch.utils.data.DataLoader(trainset, batch_size = 32,
                                              shuffle=False, num_workers=2)

    testset = datasets.MNIST(root='./data', train=False,
                                         download=True, transform=transform)

    testset.data = trainset.data[-1000:]
    testset.targets = trainset.targets[-1000:]
    testloader = torch.utils.data.DataLoader(testset, batch_size = 32,
                                             shuffle=False, num_workers=2)
    
    testloader = torch.utils.data.DataLoader(testset, batch_size = 32,
                                             shuffle=False, num_workers=2)
    
    return trainloader, testloader

In [3]:
# Bypass loading dataset using SSL unverified
ssl._create_default_https_context = ssl._create_unverified_context

# MNIST dataset 
train_loader, test_loader = data_loader()

In [4]:
# Set up save location on HHD.
SAVE_LOC = HDFP + "/demo"
if not os.path.exists(SAVE_LOC):
    os.makedirs(SAVE_LOC)

In [5]:
DECOMPOSED_LAYERS = ["classifier.1.weight", "classifier.4.weight"] # Set up layers to decompose
RANK = -1 # -1 => default rank of min(min(n, m), 8)
SCALING = -1 # -1 => default LoRA scaling of 0.5
BRANCH_ACC = "0.8072" # Set up branching point

# Set up weights for original AlexNet model
original = AlexNet()
learning_rate = 0.01

# Load from "branch point"
BRANCH_LOC = HDFP + "/lobranch-snapshot/branchpoints/branch_{}.pt".format(BRANCH_ACC)
original.load_state_dict(torch.load(BRANCH_LOC))

# Construct LoRA model from original model.
BASEPATH = SAVE_LOC + "/lora_base.pt" # Extract the LoRA bases at branchpoint.

# Note that this BASEPATH is only needed for the superstep - LoRA, but for now superstep-LoRA is simulated by restoring from
# a local version of the state stored in the ipynb memory instead of from the HHD.
# Will create a fully functional superstep / restore mechanism in future iterations of the mechanism.

w, b = getBase(original, BASEPATH)
model = AlexNet_LowRank(w, b, rank = RANK) # Create our low-rank model.
load_sd_decomp(torch.load(BRANCH_LOC), model, DECOMPOSED_LAYERS)
optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)

In [6]:
restored_accuracy = [] # We store our restoration accuracy here.

current_iter = 0 # Current iter represents the version of the model with respect to the super step.
current_set = 0 # Represents a set of checkpoints (9 default iterations + 1 super step).

acc = lambda x, y : (torch.max(x, 1)[1] == y).sum().item() / y.size(0)

for epch in range(20):
    for i, data in enumerate(train_loader, 0):
        print("Epoch: {}, Iteration: {}".format(epch, i))

        # ==========================
        # Compressing the Model
        # ==========================
        
        set_path = "/set_{}".format(current_set) # Set up file directory for the current set. (10 models + 1 superstep)
        if not os.path.exists(SAVE_LOC + set_path):
            os.makedirs(SAVE_LOC + set_path)

        if i == 0 and epch == 0: # first iteration, create baseline model
            base, base_decomp = lc.extract_weights(model, SAVE_LOC + 
                                                       "/set_{}".format(current_set), DECOMPOSED_LAYERS)
        else:
            if i % 10 == 0: 
                # super step process (every 10 iterations)
                new_model = AlexNet()

                # TODO: construct non-lazy restore from branchpoint (base lora weights) as well as lora supersteps.
                # For now the LoRA super step process is simulated via lazy_restore and the weights dictionary kept in ipynb memory.
                new_model = lazy_restore(base, base_decomp, bias, AlexNet(), 
                                          original.state_dict(), DECOMPOSED_LAYERS, rank = RANK, scaling = SCALING)

                # Changing the previous "original model" to aid the lazy restore (have only conducted 
                # restore lazily during evaluation, reason given below).
                original = new_model
                
                # Increment current set id & iteration id.
                current_set += 1
                current_iter = 0

                set_path = "/set_{}".format(current_set)
                if not os.path.exists(SAVE_LOC + set_path):
                    os.makedirs(SAVE_LOC + set_path)
                
                # Rebuilding LoRA layers => reset model!
                w, b = getBase(original)
                model = AlexNet_LowRank(w, b, rank = RANK)
                optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)
                load_sd_decomp(original.state_dict(), model, DECOMPOSED_LAYERS)

                # Extract new base + save current model as super step.
                base, base_decomp = lc.extract_weights(model, SAVE_LOC + 
                                                       "/set_{}".format(current_set), 
                                                       DECOMPOSED_LAYERS, restoring=False)

            else:
                # Delta-compression (Non-superstep)
                
                delta, decomp_delta, bias = lc.generate_delta(base, 
                                                                base_decomp, model.state_dict(), DECOMPOSED_LAYERS)
                compressed_delta, full_delta, compressed_dcomp_delta, full_dcomp_delta  = lc.compress_delta(delta, 
                                                                                                            decomp_delta)
                
                # Saving checkpoint
                lc.save_checkpoint(compressed_delta, compressed_dcomp_delta, bias, current_iter, SAVE_LOC + 
                                "/set_{}".format(current_set))
    
                base = np.add(base, full_delta) # Replace base with latest for delta to accumulate.
                base_decomp = np.add(full_dcomp_delta, base_decomp)

                current_iter += 1
        
        # ==========================
        # Training on Low-Rank Model
        # ==========================

        # Get the inputs and labels
        inputs, labels = data

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward + backward + optimize
        outputs = model(inputs)
        loss = torch.nn.functional.cross_entropy(outputs,labels)
        loss.backward()
        optimizer.step()

        if i != 0  and i % 5 == 0: # Evaluation on testing set

            # Restoration based on previous restored model, for now we do a lazy restore directly on current 
            # base because iterations across epochs might not align with % 10 superstep since this restoration
            # is meant to be taken without respect to epoch.
            # But in deployment, users define a set id and checkpoint id to generate the current base from 
            # before running this lazy restore process on the generated base, so it essentially just skips the base
            # construction process in a standard lc-lora restoration.

            restored_model = lazy_restore(base, base_decomp, bias, AlexNet(), 
                                          original.state_dict(), DECOMPOSED_LAYERS, 
                                          rank = RANK, scaling = SCALING)
            restored_accuracy.append(evaluate_accuracy(restored_model, test_loader))

Epoch: 0, Iteration: 0
saving full base model @ /volumes/Ultra Touch/demo/set_0/base_model.pt
Epoch: 0, Iteration: 1
Saving Checkpoint lc_checkpoint_0.pt @ /volumes/Ultra Touch/demo/set_0
Epoch: 0, Iteration: 2
Saving Checkpoint lc_checkpoint_1.pt @ /volumes/Ultra Touch/demo/set_0
Epoch: 0, Iteration: 3
Saving Checkpoint lc_checkpoint_2.pt @ /volumes/Ultra Touch/demo/set_0
Epoch: 0, Iteration: 4
Saving Checkpoint lc_checkpoint_3.pt @ /volumes/Ultra Touch/demo/set_0
Epoch: 0, Iteration: 5
Saving Checkpoint lc_checkpoint_4.pt @ /volumes/Ultra Touch/demo/set_0
model accuracy: 0.905
Epoch: 0, Iteration: 6
Saving Checkpoint lc_checkpoint_5.pt @ /volumes/Ultra Touch/demo/set_0
Epoch: 0, Iteration: 7
Saving Checkpoint lc_checkpoint_6.pt @ /volumes/Ultra Touch/demo/set_0
Epoch: 0, Iteration: 8
Saving Checkpoint lc_checkpoint_7.pt @ /volumes/Ultra Touch/demo/set_0
Epoch: 0, Iteration: 9
Saving Checkpoint lc_checkpoint_8.pt @ /volumes/Ultra Touch/demo/set_0
Epoch: 0, Iteration: 10
saving full ba

KeyboardInterrupt: 