In [None]:
import torch
from utils.training import train_accuracy
import matplotlib.pyplot as plt
from os import listdir
from tqdm import tqdm, trange
from data_loaders.imagenet import Train, Val
import numpy as np
from utils.pruning import get_prune_percentage

In [None]:
def rule(x):
    """Sort key for checkpoint filenames.

    Parses ``x[10:-4]`` (the name with its first 10 and last 4 characters removed)
    as an int. Names whose middle part is not an integer map to 0.

    NOTE(review): assumes names like "<10-char prefix><index><4-char suffix>",
    e.g. "real_model12.pth" — confirm against the checkpoint naming scheme.
    """
    index_text = x[10:-4]
    try:
        return int(index_text)
    except ValueError:  # a bare `except:` would also hide KeyboardInterrupt and genuine bugs
        return 0
base_dir = "saved_models/RN18_64_2"
all_models = listdir(base_dir)

# Split the checkpoints into the real-valued and quaternion families (by filename prefix),
# each ordered by the index that `rule` parses from the filename.
real_models = sorted(
    (name for name in all_models if name.startswith("real")),
    key=rule,
)
quat_models = sorted(
    (name for name in all_models if name.startswith("quat")),
    key=rule,
)
print(f"{real_models = }\n{quat_models = }")

In [None]:
# Only validation accuracy is evaluated below, so no training loader is built here
# (Train() from data_loaders.imagenet would be its source).
GPU = torch.device("cuda:0")
validation_generator = torch.utils.data.DataLoader(
    Val(),
    batch_size=1024,
    num_workers=4,
)

In [None]:
# Sanity check: value reported by get_prune_percentage for the first real-valued checkpoint.
# NOTE(review): torch.load unpickles the stored object — only load trusted checkpoints; newer
# PyTorch versions default to weights_only=True, which may reject full-model pickles.
first_real_checkpoint = torch.load(f"{base_dir}/{real_models[0]}")
get_prune_percentage(first_real_checkpoint)

In [None]:
# Evaluate every checkpoint of one model family on the validation set. The real-valued and
# quaternion families use the same procedure, so it lives in one helper instead of two
# copy-pasted loops. (Training-set accuracy used to be computed here too but was disabled.)

# Multiplier applied to get_prune_percentage(model) so both families share one x-axis.
# NOTE(review): the quaternion factor of 25 (vs. 100) presumably accounts for quaternion
# layers holding 1/4 as many weights as the matching real-valued layers — confirm.
REAL_PERCENT_SCALE = 100
QUAT_PERCENT_SCALE = 25


def evaluate_checkpoints(model_names, percent_scale, desc):
    """Load each checkpoint in ``model_names`` from ``base_dir`` and evaluate it.

    Args:
        model_names: checkpoint filenames inside ``base_dir``, evaluated in order.
        percent_scale: multiplier applied to ``get_prune_percentage(model)``.
        desc: label shown on the tqdm progress bar.

    Returns:
        ``(test_accs, prune_percs)``: lists aligned with ``model_names``; ``test_accs``
        holds ``train_accuracy(model, validation_generator, GPU)`` per checkpoint.
    """
    test_accs, prune_percs = [], []
    for model_name in tqdm(model_names, desc=desc, unit="model"):
        # torch.load unpickles the stored model object — only load trusted checkpoints.
        model = torch.load(f"{base_dir}/{model_name}")
        prune_percs.append(get_prune_percentage(model) * percent_scale)
        test_accs.append(train_accuracy(model, validation_generator, GPU))
    return test_accs, prune_percs


real_test_accs, real_prune_percs = evaluate_checkpoints(real_models, REAL_PERCENT_SCALE, "Real models")
quat_test_accs, quat_prune_percs = evaluate_checkpoints(quat_models, QUAT_PERCENT_SCALE, "Quat models")

In [None]:
# Validation accuracy vs. the (scaled) percentage from get_prune_percentage, on a log x-axis
# that is inverted so larger percentages appear on the left.
# NOTE(review): the last checkpoint of each family is left out via [:-1] — confirm intended.
fig, ax = plt.subplots()
for percs, accs, series_label in (
    (real_prune_percs, real_test_accs, "Real test acc"),
    (quat_prune_percs, quat_test_accs, "Quat test acc"),
):
    ax.plot(percs[:-1], accs[:-1], label=series_label)

ax.set_xscale("log")
ax.legend()
ax.set_xlabel("Prune percentage")
ax.set_ylabel("Accuracy")
ax.invert_xaxis()
# Tick positions are successive factors of 1/4, starting at 100.
ax.set_xticks([0.39, 1.56, 6.25, 25, 100])
ax.set_xticklabels(["0.39%", "1.56%", "6.25%", "25%", "100%"])
ax.set_ylim(10, 36)
ax.grid()

In [None]:
# https://arxiv.org/pdf/2301.04623.pdf
# https://openreview.net/pdf?id=K398CuAKVKB

In [17]:
import torch
import torchvision
from torchvision import transforms
from models.resnet_real import ResNet18 as Model

In [2]:
# Per-channel normalisation statistics shared by both pipelines.
# NOTE(review): these are the values commonly quoted for CIFAR-10, while this notebook loads
# CIFAR-100 — confirm they match what the models were trained with before changing them.
CIFAR_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR_STD = (0.2023, 0.1994, 0.2010)

# Training: random 32x32 crop from a 4-pixel-padded image plus a random horizontal flip.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])

# Evaluation: deterministic, tensor conversion + normalisation only.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])

In [6]:
# Random 4-channel batch of shape (1, 4, 32, 32), converted to a float32 NumPy array.
a = torch.randn(size=(1, 4, 32, 32)).numpy()

In [None]:
# transform_train pads/crops before ToTensor, so it needs a single PIL image (RandomCrop's pad
# rejects NumPy arrays and ToTensor rejects tensors), and its Normalize has 3-entry mean/std,
# so the image must have 3 channels. The 4-channel NumPy batch `a` therefore cannot be passed
# in directly; probe the pipeline with a random 3-channel 32x32 PIL image instead.
sample_img = transforms.ToPILImage()(torch.rand(3, 32, 32))
transform_train(sample_img).shape

In [11]:
import os

# Dataset root, defined once; set the CIFAR100_ROOT environment variable to use another
# location instead of editing this machine-specific default.
CIFAR100_ROOT = os.environ.get(
    "CIFAR100_ROOT", "/home/aritra/project/quatLT23/9_cifar100_RN18/cifar100"
)

trainset = torchvision.datasets.CIFAR100(root=CIFAR100_ROOT, train=True, download=True, transform=transform_train)
training_generator = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True, drop_last=True)

testset = torchvision.datasets.CIFAR100(root=CIFAR100_ROOT, train=False, download=True, transform=transform_test)
# drop_last=True discards a trailing partial batch; the CIFAR-100 test split (10,000 images)
# divides evenly by 100, so nothing is dropped at this batch size — revisit if it changes.
validation_generator = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, drop_last=True)

Files already downloaded and verified
Files already downloaded and verified


In [38]:
import time
from tqdm import tqdm
import numpy as np
import torch

In [29]:
# Fresh ResNet18 from models.resnet_real, moved to the first GPU.
# Args are presumably (in_channels=3, num_classes=100) for CIFAR-100 — confirm.
model = Model(3, 100)
model = model.to("cuda:0")

In [46]:
# Grab one (shuffled) training batch to experiment with.
first_batch = next(iter(training_generator))
batch_x, batch_y = first_batch

In [47]:
# Expected (N, C, H, W) = (100, 3, 32, 32).
batch_x.size()

torch.Size([100, 3, 32, 32])

In [49]:
# Channel-mixing matrix for channels-last RGB pixels (x @ mat -> 4 channels): the first three
# columns are the identity, so R, G, B pass through unchanged, and the last column appends a
# luma channel with the ITU-R BT.601 weights 0.299 R + 0.587 G + 0.114 B.
# (Blue is 0.114, not 0.144 — the luma weights must sum to 1.)
mat = torch.tensor(
    [
        [1, 0, 0, 0.299],
        [0, 1, 0, 0.587],
        [0, 0, 1, 0.114],
    ]
)

In [63]:
# No-op removed: this cell only re-assigned `batch_x` to itself; the batch is used unchanged below.

In [54]:
# Channels-last view of the batch, (N, H, W, C); show only its shape rather than dumping the tensor.
batch_x.permute(0, 2, 3, 1).shape

torch.Size([100, 32, 32, 3])

In [64]:
# Mix the 3 input channels with `mat` (3x4) to get a 4-channel batch, staying in torch (no NumPy
# round-trip or legacy torch.Tensor constructor). Equivalent to
# np.dot(x.transpose(0, 2, 3, 1), mat).transpose(0, 3, 1, 2); only the shape is displayed.
torch.einsum("bchw,cd->bdhw", batch_x, mat).shape

torch.Size([100, 4, 32, 32])

In [30]:
# Rough throughput check: one pass over the training loader including the forward pass of
# `model` on the GPU. Outputs are discarded, so autograd graph construction is disabled.
# CUDA kernels run asynchronously, so synchronize before reading the clock at both ends;
# otherwise the measurement can stop while GPU work is still queued.
# NOTE(review): `model` is left in its current train/eval mode; in train mode, layers such as
# BatchNorm (if present) update their running statistics during this loop.
torch.cuda.synchronize()
t0 = time.perf_counter()
with torch.no_grad():
    for batch_x, batch_y in tqdm(training_generator):
        model(batch_x.to("cuda:0"))
torch.cuda.synchronize()
print(time.perf_counter() - t0)

100%|██████████| 500/500 [00:04<00:00, 102.07it/s]

4.900058269500732





In [66]:
# batch_x = torch.Tensor(np.dot(batch_x.numpy().transpose(0, 2, 3, 1), mat).transpose(0, 3, 1, 2)).float()

In [76]:
# Single random channels-first (C, H, W) sample for checking the channel-mixing shapes in NumPy.
# NOTE(review): this rebinds `batch_x` (earlier a torch batch) to a NumPy array.
batch_x = np.random.standard_normal((3, 32, 32))
batch_x.shape

(3, 32, 32)

In [78]:
# Channels-last (H, W, C) sample times the (3, 4) mixing matrix -> (H, W, 4).
np.matmul(batch_x.transpose(1, 2, 0), np.asarray(mat)).shape

(32, 32, 4)