In [None]:
# Re-import edited local modules (e.g. `models`) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import timeit
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from models import ConvMixer, MlpMixer, ViT
from torchvision.datasets import CIFAR10, ImageFolder
from torch.utils.data import DataLoader 
from torchvision import transforms as T

In [None]:
# Confirm a CUDA device is visible, then release cached-but-unused allocator blocks.
print(torch.cuda.is_available())
print(torch.cuda.get_device_name(0))
# empty_cache() returns None; call it for its side effect instead of printing "None".
torch.cuda.empty_cache()
# print(torch.cuda.memory_summary(0))

In [None]:
# Hyper-parameters shared by the cells below (plain module-level globals).
batch_size = 2
hdim = 1024
depth = 32

epochs = 1

# Augmentation: RandomResizedCrop lower scale bound, RandomErasing probability,
# RandAugment magnitude / number of ops, ColorJitter strength.
scale = 0.75
reprob = 0.25
ra_m = 8
ra_n = 1
jitter = 0.1

# Presumably patch size / conv kernel size for ConvMixer (not used by the active cells).
psize = 2
conv_ks = 5

# Optimiser: AdamW weight decay, gradient-norm clipping switch, learning rate.
wd = 0.01
clip_norm = True
lr_max = 0.01

# DataLoader worker processes.
workers = 2

In [None]:
from pathlib import Path
import os

def _group_val_images_by_class(val_dir):
    """Move Tiny-ImageNet validation images into one sub-folder per class (wnid).

    ImageFolder derives labels from sub-folder names. The archive is expected to ship
    validation images flat in ``val/images/``, with labels in ``val/val_annotations.txt``
    (tab-separated: file name, wnid, bounding box). Left as-is, ImageFolder would see a
    single class ("images") and label every validation image 0. After this runs, each
    file lives in ``val/<wnid>/``. That matches the train split's class indices, assuming
    every train class appears in val.

    This is a no-op once ``val/images`` is gone, so re-running the cell is safe.
    """
    val_dir = Path(val_dir)
    images_dir = val_dir / 'images'
    if not images_dir.is_dir():
        return
    with open(val_dir / 'val_annotations.txt') as annotations:
        for line in annotations:
            fields = line.strip().split('\t')
            if len(fields) < 2:
                continue
            fname, wnid = fields[0], fields[1]
            src = images_dir / fname
            if src.exists():
                (val_dir / wnid).mkdir(exist_ok=True)
                os.replace(src, val_dir / wnid / fname)
    # rmdir fails loudly if some image had no annotation, rather than leaving an "images" class.
    images_dir.rmdir()


if not Path('data/tiny-imagenet-200').exists():
    for cmd in ('wget http://cs231n.stanford.edu/tiny-imagenet-200.zip -P data',
                'unzip -qq data/tiny-imagenet-200.zip -d data'):
        # os.system returns the exit status; fail here instead of later inside ImageFolder.
        if os.system(cmd) != 0:
            raise RuntimeError(f"command failed: {cmd}")

DATA_DIR = 'data/tiny-imagenet-200' # Original images come in shapes of [3,64,64]

# Define training and validation data paths
TRAIN_DIR = os.path.join(DATA_DIR, 'train')
VALID_DIR = os.path.join(DATA_DIR, 'val')
_group_val_images_by_class(VALID_DIR)

traindata = ImageFolder(TRAIN_DIR, transform=T.Compose([
    T.RandomResizedCrop(64, scale=(scale, 1.0)),
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
]))

valdata = ImageFolder(VALID_DIR, transform=T.Compose([
    T.Resize(64),
    T.CenterCrop(64),
    T.ToTensor(),
    T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
]))

trainloader = DataLoader(traindata, batch_size=batch_size, shuffle=True, num_workers=workers)
valloader = DataLoader(valdata, batch_size=batch_size, shuffle=False, num_workers=workers)

print(len(traindata))
print(len(valdata))

In [None]:
# Alternative dataset (disabled): CIFAR-10 with RandomResizedCrop/flip/RandAugment/ColorJitter/
# RandomErasing augmentation. Uncommenting it replaces `traindata`/`trainloader` and defines
# `testdata`/`testloader` instead of the Tiny-ImageNet `valdata`/`valloader`.
# cifar10_mean = (0.4914, 0.4822, 0.4465)
# cifar10_std = (0.2471, 0.2435, 0.2616)
# train_transform = T.Compose([
#     T.RandomResizedCrop(32, scale=(scale, 1.0), ratio=(1.0, 1.0)),
#     T.RandomHorizontalFlip(p=0.5),
#     T.RandAugment(num_ops=ra_n, magnitude=ra_m),
#     T.ColorJitter(jitter, jitter, jitter),
#     T.ToTensor(),
#     T.Normalize(cifar10_mean, cifar10_std),
#     T.RandomErasing(p=reprob)
# ])

# test_transform = T.Compose([
#     T.ToTensor(),
#     T.Normalize(cifar10_mean, cifar10_std)
# ])
# traindata = CIFAR10(root="data", train=True, download=True, transform=train_transform)
# testdata = CIFAR10(root="data", train=False, download=True, transform=test_transform)
# trainloader = DataLoader(traindata, batch_size=batch_size, shuffle=True, num_workers=workers)
# testloader = DataLoader(testdata, batch_size=batch_size, shuffle=False, num_workers=workers)

In [None]:
def get_stats(model, get_time=True, record_time_len=10, verbose=False):
    """Run a few AdamW training steps of `model` on `trainloader` and collect resource stats.

    Reads the notebook globals `trainloader`, `lr_max`, `wd`, `clip_norm` and `batch_size`.
    Steps are counted from 0, and each statistic is taken on its own step:
      * step 2 (`record_transfers`): byte sizes of X, y, the loss, plus whatever
        `model(X, get_transfer=True)` reports (this runs an extra forward pass),
      * step 3 (`record_mem`): peak CUDA memory allocated during that step,
      * steps 4 .. 3 + record_time_len (only if `get_time`): wall-clock step times.

    Args:
        model: model to train. Its forward must accept `get_transfer=True`, and it must
            expose `.width` / `.depth` when `verbose` is set.
        get_time: also record per-step wall-clock times.
        record_time_len: number of steps to time.
        verbose: print model size and the collected stats.

    Returns:
        (max_mem, transfered, step_time): peak memory in GB, list of byte counts,
        list of step times in seconds.
    """
    opt = optim.AdamW(model.parameters(), lr=lr_max, weight_decay=wd)
    criterion = nn.CrossEntropyLoss()
    max_mem = 0.
    transfered = []
    step_time = []
    # The extra forward at `record_transfers` happens before the peak counter is reset at
    # `record_mem`, so it does not show up in the measured peak.
    record_transfers = 2
    record_mem = 3
    record_time = list(range(4, 4 + record_time_len))
    # Stop right after the last step that is actually recorded.
    end_step = max(record_time[-1], record_mem) if get_time and record_time else record_mem
    if verbose:
        print(f"batch_size: {batch_size}, width: {model.width}, depth: {model.depth}")
        print(f"Total params: {sum(p.numel() for p in model.parameters()):,}")
        print(f"Total params size in gb: {sum(p.element_size()*p.nelement() for p in model.parameters())/1024**3:.4f} GB")
    model.train()
    for i, (X, y) in enumerate(trainloader):
        if i in record_time:
            # Don't charge still-queued work from the previous step to this one.
            torch.cuda.synchronize()
            start_step = time.perf_counter()
        if i == record_mem:
            torch.cuda.reset_peak_memory_stats()
        X, y = X.cuda(), y.cuda()
        if i == record_transfers:
            transfered.append(X.element_size() * X.nelement())
            transfered.append(y.element_size() * y.nelement())

        # Mixed precision (autocast / GradScaler) and the LR schedule are intentionally off here.
        opt.zero_grad()
        output = model(X)
        loss = criterion(output, y)
        if i == record_transfers:
            transfered.append(loss.element_size() * loss.nelement())
            transfered.append(model(X, get_transfer=True))

        loss.backward()
        if clip_norm:
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        if i in record_time:
            # CUDA kernels run asynchronously; wait so the time covers the whole step.
            torch.cuda.synchronize()
            step_time.append(time.perf_counter() - start_step)
        if i == record_mem:
            max_mem = torch.cuda.max_memory_allocated(0) / 1024**3
        if verbose:
            if record_time and i == record_time[-1]:
                print(f'avg step time: {np.mean(step_time):.4f} +- {np.std(step_time)}s')
            if i == record_mem:
                print(f'max_mem: {max_mem:.4f} GB')
            if i == record_transfers:
                # Total of the recorded byte counts (the mean of X/y/loss/model sizes is meaningless).
                print(f'transfered: {np.array(transfered).sum()/1024:.4f} kB')
        if i == end_step:
            break
    return max_mem, transfered, step_time

# Smoke test: stats for a 256-block ResMlp of width 1024 on 3x64x64 inputs, 1000 outputs.
# NOTE(review): ResMlp is defined in a later cell, so on a fresh kernel (Restart & Run All)
# this raises NameError — move the ResMlp cell above this one.
_ = get_stats(ResMlp(256, 1024, 2**12*3, 1000).cuda(), verbose=True)

In [None]:
def get_model_size(mem_limit, hdim=1024, gch=False, offload=False):
    """Find the deepest ResMlp whose training step fits in `mem_limit` GB of CUDA memory.

    Depth is searched by doubling from 1 and then bisecting, assuming peak memory grows
    monotonically with depth. If even a single block exceeds the budget, `hdim` is halved
    until it fits. Each probe runs a few training steps through `get_stats`.

    NOTE(review): relies on `ResMlp`, which is defined in a later cell of this notebook.

    Args:
        mem_limit: budget in GB, compared against peak allocated memory minus what was
            allocated when this function was called.
        hdim: starting hidden width.
        gch: build models with gradient checkpointing.
        offload: build models with CPU offloading (takes precedence over `gch`).

    Returns:
        (depth, hdim, get_stats(...) with timing for the chosen model, init_alloc in GB)

    Raises:
        ValueError: if no width >= 1 fits at depth 1.
    """
    init_alloc = torch.cuda.memory_allocated(0) / 1024**3

    def build(depth, hdim):
        # Offloading models stay on the CPU; the others are moved to the GPU up front.
        if offload:
            return ResMlp(depth, hdim, 2**12*3, 1000, offload=True)
        return ResMlp(depth, hdim, 2**12*3, 1000, gch=gch).cuda()

    def fits(depth, hdim):
        return get_stats(build(depth, hdim), get_time=False)[0] - init_alloc <= mem_limit

    # Shrink the width until one block fits.
    while not fits(1, hdim):
        hdim //= 2
        if hdim == 0:
            raise ValueError(f"no ResMlp fits in {mem_limit} GB")

    # Doubling phase. Invariant: depth `lo` fits, depth `hi` does not.
    lo, hi = 1, 2
    while fits(hi, hdim):
        lo, hi = hi, hi * 2

    # Bisection. It always terminates; the old `scale == 1` exit never fired when scale hit 0.
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if fits(mid, hdim):
            lo = mid
        else:
            hi = mid

    depth = lo
    return depth, hdim, get_stats(build(depth, hdim), get_time=True), init_alloc

def process_stats(out):
    """Summarise a `get_model_size` result for printing.

    Returns (depth, hdim, peak GB above the initial allocation,
    "<total transferred / 1024> Kb", "<mean step time> ± <std step time>").
    """
    depth, hdim, (max_mem, transfered, step_time), init_alloc = out
    total_kb = np.array(transfered).sum() / 1024
    time_mean = np.mean(step_time)
    time_std = np.std(step_time)
    return (
        depth,
        hdim,
        max_mem - init_alloc,
        f"{total_kb} Kb",
        f"{time_mean} ± {time_std}",
    )

# Largest ResMlp that trains within each GPU-memory budget, per memory-saving strategy.
# Extend the budget tuple to (1, 2, 4) to also reproduce the 2GB / 4GB runs (slow).
# NOTE(review): get_model_size builds ResMlp, which is defined in a later cell.
_budget_strategies = (
    ("baseline", {}),
    ("checkpointing", {"gch": True}),
    ("offloading", {"offload": True}),
)
for budget_gb in (1,):
    for label, kwargs in _budget_strategies:
        print(f"mem_budget: {budget_gb}gb, {label}", process_stats(get_model_size(budget_gb, **kwargs)))
    print()

In [None]:
# Measure host-to-device (CPU -> GPU) copy bandwidth for tensors of 2^1 .. 2^20 elements.
# Both timed expressions reduce on the GPU and copy one scalar back. The device-to-host
# part therefore cancels in (real - baseline), leaving roughly the H2D copy of `data`.
print("Measuring data transfer latency...")
bandwidth = []  # one array of `n_repeat` MB/s samples per size
trange = np.arange(1, 21)
n_calls, n_repeat = 10000, 7
for i in trange:
    data_amt = 1 << i  # number of elements (default dtype, normally float32)
    data = torch.randn(data_amt)
    nbytes = data_amt * data.element_size()
    print(f"Data amount: {data_amt} = {nbytes/2**20:.4f}MB = 2^{i} elements")
    cudata = data.cuda()
    for _ in range(10):  # warm-up
        cudata.mean().cpu().numpy()
    baseline = np.array(timeit.repeat(lambda: cudata.mean().cpu().numpy(), number=n_calls, repeat=n_repeat)) / n_calls
    for _ in range(10):  # warm-up
        data.cuda().mean().cpu().numpy()
    real = np.array(timeit.repeat(lambda: data.cuda().mean().cpu().numpy(), number=n_calls, repeat=n_repeat)) / n_calls
    bandwidth.append(nbytes / (2**20 * (real - baseline)))
    print(f"bandwidth: {np.mean(bandwidth[-1]):.4e} ± {np.std(bandwidth[-1]):.4e} MB/s")
    del data, cudata
    print()

In [None]:
import matplotlib.pyplot as plt
# Plot host-to-device bandwidth against transfer size (mean ± std over the timeit repeats).
npband = np.array(bandwidth)  # shape (len(trange), repeats), MB/s
sizes_mb = 4 * np.float32(2)**trange / 2**20  # 4 bytes per float32 element
band_mean, band_std = npband.mean(1), npband.std(1)

fig, ax = plt.subplots()
ax.plot(sizes_mb, band_mean, label="mean")
ax.fill_between(sizes_mb, band_mean - band_std, band_mean + band_std, alpha=0.5, label="std")
# Sizes are powers of two; on a linear axis the small transfers collapse onto x=0.
ax.set_xscale("log")
ax.set_xlabel("Data amount (MB)")
ax.set_ylabel("Bandwidth (MB/s)")
ax.set_title("Host-to-device bandwidth")
ax.legend()  # the labels above were previously never shown
fig.savefig("bandwidth_lut.pdf")

In [None]:
# CUDA memory (GB) currently held by tensors on device 0, displayed as the cell output.
init_alloc = torch.cuda.memory_allocated(device=0) / 2**30
init_alloc

In [None]:
# Model, LR schedule and optimiser for the full training run in the next cell.
depth = 14
hdim = 1024
# model = ConvMixer(hdim, depth, kernel_size=9, patch_size=7, n_classes=1000)
model = MlpMixer(num_blocks=depth, embed_dim=hdim).cuda()


def lr_schedule(t):
    """Piecewise-linear learning rate at training time `t`, measured in (fractional) epochs.

    Ramps 0 -> lr_max over the first 40% of training, decays to lr_max/20 by 80%, then
    decays to 0 at `epochs`. True division keeps the knots distinct. With `//` and
    epochs=1, the inner knots became 0, and the warm-up and peak disappeared.
    """
    return np.interp([t], [0, epochs*2/5, epochs*4/5, epochs],
                     [0, lr_max, lr_max/20.0, 0])[0]


opt = optim.AdamW(model.parameters(), lr=lr_max, weight_decay=wd)
criterion = nn.CrossEntropyLoss()
# Currently unused: the AMP calls in the training loop are commented out.
scaler = torch.cuda.amp.GradScaler()

In [None]:
# Full training run: per-epoch train/eval accuracy plus step time and peak memory on early steps.
print(f"batch_size: {batch_size}, hdim: {hdim}, depth: {depth}")
print(f"Total params: {sum(p.numel() for p in model.parameters()):,}")
print(f"Total params size in gb: {sum(p.element_size()*p.nelement() for p in model.parameters())/1024**3:.4f}GB")
max_mem = 0.
transfered = 0  # bytes of X, y and the loss on step `record_mem`
step_time = []
record_mem = 3
record_time = list(range(4, 4+100))
# Evaluate on the split the active data cell builds. `testloader` only exists in the
# commented-out CIFAR-10 cell, so referencing it raised NameError.
eval_loader = valloader
# Report the class actually trained (the label used to be hard-coded as "[ConvMixer]").
model_name = type(model).__name__
for epoch in range(epochs):
    start = time.time()
    train_loss, train_acc, n = 0, 0, 0
    model.train()
    for i, (X, y) in enumerate(trainloader):
        if i in record_time: start_step = time.time()
        if i == record_mem: torch.cuda.reset_peak_memory_stats(0)
        X, y = X.cuda(), y.cuda()
        if i == record_mem:
            transfered += X.element_size() * X.nelement()
            transfered += y.element_size() * y.nelement()

        lr = lr_schedule(epoch + (i + 1)/len(trainloader))
        opt.param_groups[0].update(lr=lr)

        opt.zero_grad()
        # Mixed precision (autocast + GradScaler) is disabled for training: fp32 throughout.
        output = model(X)
        loss = criterion(output, y)
        if i == record_mem:
            transfered += loss.element_size() * loss.nelement()

        loss.backward()
        if clip_norm:
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()

        # .item() blocks until the GPU finishes this step, so the step time below
        # includes the queued CUDA work.
        train_loss += loss.item() * y.size(0)
        train_acc += (output.max(1)[1] == y).sum().item()
        n += y.size(0)
        if i in record_time: step_time.append(time.time() - start_step)
        if i == record_time[-1]: print(f'avg step time: {np.mean(step_time):.4f} +- {np.std(step_time)}s')

        if i == record_mem:
            max_mem = torch.cuda.max_memory_allocated(0)/1024**3
            print(f'max_mem: {max_mem:.4f} GB, transfered: {transfered/1024**2:.4f} MB')

    model.eval()
    test_acc, m = 0, 0
    with torch.no_grad():
        for i, (X, y) in enumerate(eval_loader):
            X, y = X.cuda(), y.cuda()
            with torch.cuda.amp.autocast():
                output = model(X)
            test_acc += (output.max(1)[1] == y).sum().item()
            m += y.size(0)

    print(f'[{model_name}] Epoch: {epoch} | Train Acc: {train_acc/n:.4f}, Test Acc: {test_acc/m:.4f}, Time: {time.time() - start:.1f}, lr: {lr:.6f}')


In [None]:
import torch
import math
from models import Residual
from torch.utils import checkpoint

class OffloadCheckpoint(torch.autograd.Function):
    """Activation checkpointing that also parks the wrapped module and its input on the CPU.

    Usage: ``y = OffloadCheckpoint.apply(module, x)``.

    Forward moves ``fn`` to the GPU and runs it without recording a graph. It then moves
    ``fn`` back to the CPU and keeps only a CPU copy of ``x``.

    Backward moves ``fn`` to the GPU again and recomputes the forward with grad enabled.
    It backpropagates the incoming gradient, which accumulates parameter gradients inside
    ``fn``, then calls ``fn.cpu()`` again. That presumably also moves those ``.grad``
    tensors (torch's ``Module._apply`` converts grads) — confirm for your torch version.
    """

    @staticmethod
    def forward(ctx, fn, x):
        with torch.no_grad():
            y = fn.cuda()(x)
        # Module.cpu() moves parameters in place and returns the module itself.
        ctx.run_function = fn.cpu()
        ctx.inp = x.detach().cpu()
        # Return the input gradient on the device the input came from (was hard-coded "cuda:0").
        ctx.inp_device = x.device
        return y

    @staticmethod
    def backward(ctx, grad):
        # Fresh GPU leaf so the recomputation populates `x.grad`; detach() keeps it a leaf
        # even if backward runs with grad mode enabled (create_graph=True).
        x = ctx.inp.cuda().detach().requires_grad_(True)
        fn = ctx.run_function.cuda()
        with torch.enable_grad():
            outputs = fn(x)
        torch.autograd.backward(outputs, grad.cuda())
        fn.cpu()
        # No gradient for `fn` (a module, not a tensor input).
        return None, x.grad.to(ctx.inp_device)


def Block(hdim):
    """Build one residual unit: ``Residual(Linear(hdim, hdim) -> GELU)``.

    ``Residual`` comes from ``models`` (presumably computing ``x + f(x)`` — confirm there).
    A ``def`` replaces the former lambda assignment (PEP 8 E731), so reprs and tracebacks
    show ``Block``.
    """
    return Residual(
        nn.Sequential(
            nn.Linear(hdim, hdim),
            nn.GELU(),
        )
    )


class ResMlp(nn.Module):
    """Residual MLP: Flatten + Linear stem, `depth` residual Linear+GELU blocks, Linear head.

    Args:
        depth: number of residual blocks.
        width: hidden size of every block.
        input_dim: flattened input size (the notebook passes 2**12*3 = 3*64*64).
        out_dim: number of output logits.
        gch: group blocks into chunks of ~sqrt(depth) and wrap each chunk in
            `torch.utils.checkpoint` (activations recomputed during backward).
        offload: same chunking, but run each chunk through `OffloadCheckpoint`, which keeps
            chunk parameters and the chunk input on the CPU between passes.
    """

    def __init__(self, depth, width, input_dim, out_dim, gch=False, offload=False):
        super().__init__()
        self.depth = depth
        self.width = width
        self.stem = nn.Sequential(nn.Flatten(), nn.Linear(input_dim, width))
        self.gch = gch
        self.offload = offload

        if self.gch or self.offload:
            # sqrt(depth)-sized segments; max(1, ...) avoids dividing by zero for depth 0.
            chunk_size = max(1, math.isqrt(depth))
            num_chunks, leftover = divmod(depth, chunk_size)
            # Only add a trailing chunk when blocks are left over. An empty Sequential would
            # be a pointless checkpoint/offload round-trip and inflate the transfer count.
            sizes = [chunk_size] * num_chunks + ([leftover] if leftover else [])
            self.blocks = nn.ModuleList([
                nn.Sequential(*[Block(width) for _ in range(size)]) for size in sizes
            ])
        else:
            self.blocks = nn.ModuleList([Block(width) for _ in range(depth)])
        self.head = nn.Linear(width, out_dim)

    def forward(self, x, get_transfer=False):
        """Return logits; with `get_transfer=True`, return estimated host<->device bytes instead.

        The estimate is 0 unless `offload` is set. In offload mode it runs a full forward
        pass and sums each chunk's output size plus its parameter size.
        """
        if not self.offload:
            if get_transfer:
                return 0
            x = self.stem(x)
            for block in self.blocks:
                if self.gch:
                    # Explicit use_reentrant keeps the historical default and silences
                    # the torch warning about passing it explicitly.
                    x = checkpoint.checkpoint(block, x, use_reentrant=True)
                else:
                    x = block(x)
            return self.head(x)

        transfered = 0
        x = self.stem.cuda()(x)
        for block in self.blocks:
            with torch.autograd.graph.save_on_cpu(pin_memory=True):
                x = OffloadCheckpoint.apply(block, x)
            if get_transfer:
                transfered += x.element_size() * x.nelement() + sum(
                    p.element_size() * p.nelement() for p in block.parameters()
                )
        x = self.head.cuda()(x)
        if get_transfer:
            return transfered
        return x

In [None]:
depth, width = 498, 1024
print(f"depth: {depth}, width: {width}")

# A single fixed batch shared by every variant below: X on the GPU, y left on the CPU.
X, y = next(iter(trainloader))
X = X.to("cuda")

def test_model(name, mod_model, init_alloc, baseline=None):
    """Benchmark one forward+backward of `mod_model` on the notebook-global batch (X, y).

    Prints peak CUDA memory above `init_alloc`, mean seconds per forward+backward iteration,
    and the estimated transferred bytes. If `baseline` is given, its parameter values are
    copied into `mod_model` first so that gradients can be compared across variants.

    Args:
        name: header printed before the stats.
        mod_model: model variant under test.
        init_alloc: bytes allocated before `mod_model` was built (subtracted from the peak).
        baseline: optional model whose parameters are copied into `mod_model`.

    Returns:
        CPU copies of the parameter gradients from the first backward pass (taken before
        the timing runs, which keep accumulating into `.grad`).
    """
    print(name)
    print(f"init alloc: {init_alloc/1024**3:.4f} GB")
    if baseline is not None:
        for p, q in zip(baseline.parameters(), mod_model.parameters()):
            q.data = p.data.to(q.device)
    torch.cuda.reset_peak_memory_stats(0)
    loss = nn.CrossEntropyLoss()(mod_model(X), y.cuda())
    loss.backward()
    max_mem = torch.cuda.max_memory_allocated(0) - init_alloc
    print(f"used: {max_mem/2**30} GB")
    res = [p.grad.clone().cpu() for p in mod_model.parameters()]

    def train_step():
        # backward() must be *called*: the previous lambdas ended in `.backward`, an attribute
        # access, so only the forward pass was timed.
        nn.CrossEntropyLoss()(mod_model(X), y.cuda()).backward()
        # Wait for queued CUDA work so the wall-clock time covers the whole step.
        torch.cuda.synchronize()

    timeit.timeit(train_step, number=10)  # warm-up
    speed = timeit.timeit(train_step, number=10) / 10
    print(f"speed: {speed:.4f} s")
    transfered = mod_model(X, get_transfer=True) + X.element_size() * X.nelement() + y.element_size() * y.nelement() + loss.element_size() * loss.nelement()
    print(f"transfered: {transfered/2**20} MB")
    print()
    return res

init_alloc = torch.cuda.memory_allocated(0)
baseline = ResMlp(depth, width, 2**12*3, 1000).cuda()
base = test_model("Baseline", baseline, init_alloc)

# CPU-only reference: same weights, forward and backward entirely on the CPU.
print("Naive")
init_alloc = torch.cuda.memory_allocated(0)
print(f"init alloc: {init_alloc/1024**3:.4f} GB")
model_naive = ResMlp(depth, width, 2**12*3, 1000)
for p, q in zip(baseline.parameters(), model_naive.parameters()):
    q.data = p.data.cpu()
torch.cuda.reset_peak_memory_stats(0)
loss = nn.CrossEntropyLoss()(model_naive(X.cpu()), y)
loss.backward()
max_mem = torch.cuda.max_memory_allocated(0) - init_alloc
print(f"used: {max_mem/2**30} GB")
# Snapshot gradients before timing: the timed steps keep accumulating into .grad.
naive = [p.grad.clone().cpu() for p in model_naive.parameters()]
# backward() must be called; a bare `.backward` attribute access would time the forward only.
speed = timeit.timeit(lambda: nn.CrossEntropyLoss()(model_naive(X.cpu()), y).backward(), number=10)/10
print(f"speed: {speed:.4f} s")
del model_naive, p, q, loss
print()

init_alloc = torch.cuda.memory_allocated(0)
gch = test_model("Gradient Checkpointing", ResMlp(depth, width, 2**12*3, 1000, gch=True).cuda(), init_alloc, baseline=baseline)
init_alloc = torch.cuda.memory_allocated(0)
offload = ResMlp(depth, width, 2**12*3, 1000, offload=True)
offload = test_model("Gradient Checkpointing with Offload", offload, init_alloc, baseline=baseline)

# init_alloc = torch.cuda.memory_allocated(0)
# bigger = ResMlp(114, width, 2**12*3, 1000, offload=True)
# test_model("Gradient Checkpointing with Offload Bigger", bigger, init_alloc)
# del bigger

# Memory-saving variants must reproduce the baseline gradients.
print(all(torch.allclose(b, g) for b, g in zip(base, gch)))
print(all(torch.allclose(b, o) for b, o in zip(base, offload)))
del baseline
del X, y
print(torch.cuda.memory_allocated(0)/2**30)

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Recorded benchmark results per GPU-memory budget (the commented depth is the ResMlp depth
# used for that budget, width 1024). "CPU Only" corresponds to the "Naive" run.
# Units follow the benchmark prints: "used" = peak CUDA memory above the initial allocation
# (GB), "speed" = seconds per iteration, "transferred" = estimated host<->device traffic (MB).
# NOTE(review): if these numbers came from a test_model whose timing lambda ended in
# `.backward` (attribute access, not a call), "speed" covers the forward pass only — re-measure.
df = {
"1GB": {
    # "depth": 114, "width": 1024,
    "Baseline": {
        "used": 0.9930810928344727,
        "speed": 0.0070,
        "transferred": 0.09376907348632812,
    },

    "CPU Only": {
        "used": 0.0,
        "speed": 0.0220,
        "transferred": 0.0
    },

    "Gradient Checkpointing": {
        "used": 0.49654483795166016,
        "speed": 0.0067,
        "transferred": 0.09376907348632812,
    },

    "Gradient Checkpointing with Offload": {
        "used": 0.13294696807861328,
        "speed": 0.2174,
        "transferred": 456.6328315734863,
    },
},

"2GB": {
    # "depth": 241, "width": 1024,
    "Baseline": {
        "used": 1.9862375259399414,
        "speed": 0.0145,
        "transferred": 0.09376907348632812,
    },

    "CPU Only": {
        "used": 0.0,
        "speed": 0.0441,
        "transferred": 0.0,
    },

    "Gradient Checkpointing": {
        "used": 0.9931230545043945,
        "speed": 0.0131,
        "transferred": 0.09376907348632812,
    },

    "Gradient Checkpointing with Offload": {
        "used": 0.17204761505126953,
        "speed": 0.4588,
        "transferred": 965.1679878234863,
    },
},

"4GB": {
    # "depth": 498, "width": 1024,
    "Baseline": {
        "used": 3.9960107803344727,
        "speed": 0.0300,
        "transferred": 0.09376907348632812,
    },

    "CPU Only": {
        "used": 0.0,
        "speed": 0.0888,
        "transferred": 0.0,
    },

    "Gradient Checkpointing": {
        "used": 1.9980096817016602,
        "speed": 0.0264,
        "transferred": 0.09376907348632812,
    },

    "Gradient Checkpointing with Offload": {
        "used": 0.22678852081298828,
        "speed": 0.9462,
        "transferred": 1994.2187690734863,
    },
}
}

In [None]:
# Budgets as columns, strategies as rows; each cell holds that strategy's metrics dict.
pd.DataFrame(df)

In [None]:
# One (strategy x metric) table per budget, stacked under a budget-level index, as LaTeX.
per_budget_tables = {}
for budget, results in df.items():
    per_budget_tables[budget] = pd.DataFrame(results).T
print(pd.concat(per_budget_tables, axis=0).to_latex())