In [1]:
%mkdir results
%cd /content/results
%mkdir first_train
%mkdir second_train
%mkdir whole_train
%cd ..

/content/results
/content


In [2]:
%%writefile dataloader.py
from torchvision.datasets import CIFAR100
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader, Dataset

class ClassFilter(Dataset):
    """In-memory dataset that keeps only the examples whose label passes a predicate.

    ``examples`` is iterated once at construction time. Every item whose second
    element (``item[1]``, the label) makes ``filter_function`` return a truthy
    value is stored in ``self.data``; indexing and ``len`` then operate on that list.
    """

    def __init__(self, examples, filter_function):
        kept = []
        for sample in examples:
            if filter_function(sample[1]):
                kept.append(sample)
        self.data = kept

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        return self.data[item]

def makeLoader(pos: str, train: bool, filter_function, batch_size: int = 64, shuffle: bool = True):
    """Build a DataLoader over the CIFAR-100 examples whose label passes ``filter_function``.

    Args:
        pos: root directory for the dataset; it is downloaded there if missing.
        train: True for the training split, False for the test split.
        filter_function: called with each example's label; truthy keeps the example.
        batch_size: number of examples per batch (defaults to the previous fixed 64).
        shuffle: whether to reshuffle every epoch (defaults to the previous fixed True).

    Returns:
        A DataLoader over a ClassFilter, which materialises the kept examples
        (already converted with ToTensor) in memory when it is constructed.
    """
    dataset = CIFAR100(pos, train=train, transform=ToTensor(), download=True)
    return DataLoader(ClassFilter(dataset, filter_function), batch_size=batch_size, shuffle=shuffle)


Writing dataloader.py


In [3]:
%%writefile palib.py
import torch
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Module, Sequential, BatchNorm2d, ReLU

# Module-wide compute device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def extend_tensor(in_tensor: torch.Tensor, in_size: list, out_size: list):
    """Zero-pad ``in_tensor`` at the end of each dimension up to shape ``out_size``.

    The original values keep their indices; every newly added entry is zero.

    Args:
        in_tensor: tensor to extend.
        in_size: shape of ``in_tensor`` as a list; must match ``in_tensor.shape``.
        out_size: target shape, with the same number of dimensions and no
            dimension smaller than in ``in_size``.

    Returns:
        A new tensor of shape ``out_size``, detached from ``in_tensor``'s graph,
        with ``requires_grad`` set to True.

    Raises:
        ValueError: if ``in_size`` does not match the tensor, the ranks differ,
            or ``out_size`` is smaller than ``in_size`` in some dimension.
    """
    if list(in_tensor.shape) != in_size:
        raise ValueError("in_tensor and in_size mismatch")
    if len(in_size) != len(out_size):
        raise ValueError(f"cannot extend tensor size of {in_size} to {out_size}")
    # Shrinking would need a negative-sized padding block; reject it up front with a clear message.
    if any(dst < src for src, dst in zip(in_size, out_size)):
        raise ValueError(f"cannot extend tensor size of {in_size} to {out_size}")
    for dim, (src, dst) in enumerate(zip(in_size, out_size)):
        if src == dst:
            continue
        # Earlier dimensions are already at their target size and later ones are still at
        # their input size, so the padding block matches the current shape except along `dim`.
        pad_shape = list(in_tensor.shape)
        pad_shape[dim] = dst - src
        # Allocate the padding on the tensor's own device so torch.cat never mixes devices.
        padding = torch.zeros(pad_shape, device=in_tensor.device)
        in_tensor = torch.cat((in_tensor, padding), dim=dim)
    return in_tensor.clone().detach().requires_grad_(True)

class PAConv2d(Conv2d):
    """Conv2d whose parameters can be copied into a larger PAConv2d with ``self >> other``."""

    def __rshift__(self, other):
        """Copy this layer's weight and bias into ``other``, zero-padded to ``other``'s shapes.

        ``other`` is modified in place; nothing is returned.

        Raises:
            ValueError: if this layer has a bias but ``other`` was built with
                ``bias=False``, or if ``extend_tensor`` rejects the shapes.
        """
        with torch.no_grad():
            other.weight.copy_(extend_tensor(self.weight, list(self.weight.shape), list(other.weight.shape)))
            # A layer built with bias=False has bias set to None, which cannot be extended or copied.
            if self.bias is None:
                # No bias here acts like an all-zero bias, matching the zero padding used elsewhere.
                if other.bias is not None:
                    other.bias.zero_()
            elif other.bias is None:
                raise ValueError("cannot transfer a bias into a layer created with bias=False")
            else:
                other.bias.copy_(extend_tensor(self.bias, list(self.bias.shape), list(other.bias.shape)))


class PALinear(Linear):
    """Linear layer whose parameters can be copied into a larger PALinear with ``self >> other``."""

    def __rshift__(self, other):
        """Copy this layer's weight and bias into ``other``, zero-padded to ``other``'s shapes.

        ``other`` is modified in place; nothing is returned.

        Raises:
            ValueError: if this layer has a bias but ``other`` was built with
                ``bias=False``, or if ``extend_tensor`` rejects the shapes.
        """
        with torch.no_grad():
            other.weight.copy_(extend_tensor(self.weight, list(self.weight.shape), list(other.weight.shape)))
            # A layer built with bias=False has bias set to None, which cannot be extended or copied.
            if self.bias is None:
                # No bias here acts like an all-zero bias, matching the zero padding used elsewhere.
                if other.bias is not None:
                    other.bias.zero_()
            elif other.bias is None:
                raise ValueError("cannot transfer a bias into a layer created with bias=False")
            else:
                other.bias.copy_(extend_tensor(self.bias, list(self.bias.shape), list(other.bias.shape)))

class PAResBlock(Module):
    """Residual block: two PAConv2d + BatchNorm2d stages added to a shortcut, then ReLU.

    The shortcut is the identity when ``in_channels == out_channels`` and a
    bias-free 1x1 Conv2d followed by BatchNorm2d otherwise. Stride is fixed at 1.

    ``self >> other`` copies the two PAConv2d layers into ``other``'s PAConv2d
    layers (zero-padded to ``other``'s shapes).
    NOTE(review): ``>>`` touches neither the BatchNorm2d layers nor the 1x1 shortcut
    convolution, so ``other`` keeps its own freshly initialised values for those;
    confirm that this is intended.
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super(PAResBlock, self).__init__()
        self.kernel_size = kernel_size
        self.stride = 1
        self.in_channels = in_channels
        self.out_channels = out_channels
        # "Same" padding for odd kernels, so the conv branch keeps the input's spatial size
        # and can be added to the shortcut. Equals the previous fixed padding of 1 for
        # kernel_size == 3; other sizes used to break the addition in forward().
        if isinstance(kernel_size, (tuple, list)):
            padding = tuple(k // 2 for k in kernel_size)
        else:
            padding = kernel_size // 2
        self.layers = Sequential(
            PAConv2d(self.in_channels, self.out_channels, kernel_size=self.kernel_size, stride=self.stride, padding=padding),
            BatchNorm2d(num_features=self.out_channels),
            ReLU(),
            PAConv2d(self.out_channels, self.out_channels, kernel_size=self.kernel_size, stride=self.stride, padding=padding),
            BatchNorm2d(num_features=self.out_channels),
        )
        self.relu = ReLU()
        # A projection shortcut is only needed when the channel count changes.
        if self.in_channels != self.out_channels:
            self.shortcut = Sequential(
                Conv2d(self.in_channels, self.out_channels, kernel_size=1, stride=self.stride, bias=False),
                BatchNorm2d(num_features=self.out_channels)
            )

    def forward(self, x):
        """Return ``relu(shortcut(x) + layers(x))``, with an identity shortcut when channels match."""
        if self.in_channels != self.out_channels:
            return self.relu(self.shortcut(x) + self.layers(x))
        return self.relu(x + self.layers(x))

    def __rshift__(self, other):
        """Copy both PAConv2d layers of this block into the matching layers of ``other``."""
        # Indices 0 and 3 are the two PAConv2d entries of self.layers.
        self.layers[0] >> other.layers[0]
        self.layers[3] >> other.layers[3]

def formResBlocks(in_channels, out_channels, nums):
    """Return a list of 3x3 PAResBlocks forming one stage.

    The first block maps ``in_channels`` to ``out_channels``; it is followed by
    ``nums - 1`` blocks that keep ``out_channels``. The first block is always
    created, even when ``nums`` is less than 1.
    """
    stage = [PAResBlock(in_channels, out_channels, 3)]
    for _ in range(nums - 1):
        stage.append(PAResBlock(out_channels, out_channels, 3))
    return stage

def formLinears(linear):
    """Build the layers of a fully connected head from a list of layer widths.

    Each consecutive pair of widths becomes a PALinear followed by a ReLU; the
    trailing ReLU is dropped so the last layer's output is not activated.
    Returns an empty list when fewer than two widths are given.
    """
    stack = []
    for idx in range(len(linear) - 1):
        stack += [PALinear(linear[idx], linear[idx + 1]), ReLU()]
    # Strip the activation that follows the final linear layer.
    return stack[:-1]

class PAResNet(Module):
    """Residual network made of PAResBlock stages followed by a PALinear head.

    Stage ``i`` holds ``layer_nums[i]`` blocks with ``layer_channels[i]`` output
    channels; consecutive stages are separated by a 2x2 max-pool. ``self >> other``
    copies the parameters of every PAResBlock and PALinear into the layer at the
    same position in ``other``, which is expected to have the same layout.
    """

    def __init__(self, layer_nums: list, layer_channels: list, linear: list, input_size: int = 32):
        """Build the network.

        Args:
            layer_nums: number of residual blocks in each stage.
            layer_channels: output channels of each stage; same length as ``layer_nums``.
            linear: widths of the fully connected layers after flattening; the
                last entry is the number of outputs.
            input_size: height/width of the square, 3-channel input images
                (defaults to the previously hard-coded 32).

        Raises:
            ValueError: if ``layer_nums`` and ``layer_channels`` differ in length.
        """
        super().__init__()
        if len(layer_nums) != len(layer_channels):
            raise ValueError("unexpected layer_nums and layer_channels size")

        li = []
        li.extend(formResBlocks(3, layer_channels[0], layer_nums[0]))
        for i in range(len(layer_nums)-1):
            li.append(MaxPool2d(2, 2))
            # Each later stage uses its own block count (this used to read layer_nums[1] for all of them).
            li.extend(formResBlocks(layer_channels[i], layer_channels[i+1], layer_nums[i+1]))

        self.blocks = Sequential(*li)
        # Every max-pool halves the spatial side, so one shift per pooling layer gives the final side.
        side = input_size >> (len(layer_nums) - 1)
        self.linear = Sequential(Flatten(), *formLinears([layer_channels[-1] * side * side] + linear))

    def forward(self, x):
        """Run the residual stages, then the flattened linear head."""
        return self.linear(self.blocks(x))

    def __rshift__(self, other):
        """Copy every PAResBlock and PALinear of this network into the same position of ``other``."""
        for i, block in enumerate(self.blocks):
            if isinstance(block, PAResBlock):
                block >> other.blocks[i]

        for i, layer in enumerate(self.linear):
            if isinstance(layer, PALinear):
                layer >> other.linear[i]

Writing palib.py


In [4]:
#import libraries
from dataloader import makeLoader
from palib import PAResNet
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
import torch
from matplotlib import pyplot as plt
from tqdm import tqdm
import time

In [5]:
# Select the compute device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [6]:
# Load data: CIFAR-100 loaders split by label (classes 0-74, classes 75-99, and all classes).
# `device` is already defined by the device cell above, so it is not redefined here.
# A relative directory replaces the raw string "\data": that is a drive-root path on Windows
# and, on Linux/Colab, a directory literally named "\data" in the working directory.
DATA_ROOT = "./data"
first_train_loader = makeLoader(DATA_ROOT, True, lambda x: x < 75)
first_test_loader = makeLoader(DATA_ROOT, False, lambda x: x < 75)
second_test_loader = makeLoader(DATA_ROOT, False, lambda x: x >= 75)
train_loader = makeLoader(DATA_ROOT, True, lambda x: True)
test_loader = makeLoader(DATA_ROOT, False, lambda x: True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to \data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:12<00:00, 13175348.92it/s]


Extracting \data/cifar-100-python.tar.gz to \data
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [7]:
# Define the loss function shared by every training and evaluation loop below:
# cross-entropy with its default settings, applied to the model outputs and integer labels.
loss_fn = CrossEntropyLoss()

In [8]:
#define other useful functions
def calc_loss_and_accuracy(output, labels):
    """Return ``(loss, correct, total)`` for one batch.

    ``loss`` is ``loss_fn(output, labels)``; ``correct`` is a tensor holding the
    number of rows whose arg-max along dim 1 equals the label; ``total`` is the
    number of predictions as a plain int.
    """
    batch_loss = loss_fn(output, labels)
    predicted = torch.max(output, 1).indices
    hits = (predicted == labels).squeeze().sum()
    return batch_loss, hits, predicted.numel()

def test_(model, loader):
    """Evaluate ``model`` over every batch in ``loader``.

    The model is switched to eval mode and left there; callers switch back with
    ``model.train()``. Gradients are disabled during the pass.

    Returns:
        ``(ss, cnt, ls)``: number of correct predictions, number of examples
        seen, and the sum over batches of ``loss * batch size`` (divide by
        ``cnt`` for the average loss).
    """
    ss, cnt, ls = 0, 0, 0.0
    model.eval()
    # Evaluation never calls backward(), so skip building autograd graphs for the test batches.
    with torch.no_grad():
        for batch in tqdm(loader, unit="batch", total=len(loader)):
            inputs, labels = batch[0].to(device), batch[1].to(device)
            output = model(inputs)
            loss, acc, cor = calc_loss_and_accuracy(output, labels)
            ss += acc.item()
            cnt += cor
            # Weight the batch loss by the batch size so ls / cnt is a per-example average.
            ls += loss.item() * cor
    return ss, cnt, ls

In [9]:
#training and testing: pa
# Phase 1 trains a smaller network on the first 75 classes; its parameters are then
# copied (zero-padded) into a larger 100-class network, which phase 2 trains on all classes.
st_time = time.time()

# Histories. Per-batch: training loss. Per-epoch: accuracies and test loss.
# Every *_t_* list holds the elapsed seconds since st_time at which the value was recorded.
train_loss_pa = []
train_acc_pa = []
train_t_pa = []
train_acc_t_pa = []
test_loss_pa = []
test_acc_pa = []
test_t_pa = []
test_acc_t_pa = []
test_acc_sub_pa = []


def _run_pa_phase(model, optimizer, epochs, fit_loader, eval_loader, save_dir, sub_loader=None):
    """Train ``model`` for each epoch in ``epochs`` and append to the ``*_pa`` histories.

    After every epoch the model is evaluated on ``eval_loader`` and, when given,
    on ``sub_loader`` (otherwise 0 is logged as the sub-set accuracy); its
    state dict is then saved to ``{save_dir}/params{epoch}.pt``.
    """
    for epoch in epochs:
        ss, cnt = 0, 0
        print(epoch)
        model.train()
        for batch in tqdm(fit_loader, unit="batch", total=len(fit_loader)):
            inputs, labels = batch[0].to(device), batch[1].to(device)
            optimizer.zero_grad()
            output = model(inputs)
            loss, acc, cor = calc_loss_and_accuracy(output, labels)
            loss.backward()
            optimizer.step()
            ss += acc.item()
            cnt += cor
            train_loss_pa.append(loss.item())
            train_t_pa.append(time.time() - st_time)
        train_acc_pa.append(ss / cnt * 100)
        train_acc_t_pa.append(time.time() - st_time)

        ss, cnt, ls = test_(model, eval_loader)
        test_loss_pa.append(ls / cnt)
        test_t_pa.append(time.time() - st_time)
        test_acc_pa.append(ss / cnt * 100)
        test_acc_t_pa.append(time.time() - st_time)

        if sub_loader is None:
            # Keeps test_acc_sub_pa aligned with test_acc_t_pa for plotting.
            test_acc_sub_pa.append(0)
        else:
            ss, cnt, ls = test_(model, sub_loader)
            test_acc_sub_pa.append(ss / cnt * 100)
        torch.save(model.state_dict(), f"{save_dir}/params{epoch}.pt")


# Phase 1: smaller network, first 75 classes only.
model = PAResNet([2, 2, 2], [48, 96, 192], [384, 192, 75]).to(device)
print(list(map(lambda x: x.size(), model.parameters())))
optimizer = Adam(model.parameters(), lr=1e-4)
_run_pa_phase(model, optimizer, range(40), first_train_loader, first_test_loader, "./results/first_train")

# Grow: copy the trained parameters into the larger network, then continue with a fresh optimizer.
new_model = PAResNet([2, 2, 2], [64, 128, 256], [512, 256, 100]).to(device)
print(list(map(lambda x: x.size(), new_model.parameters())))
model >> new_model
del model
model = new_model
optimizer = Adam(model.parameters(), lr=1e-4)

# Phase 2: larger network, all 100 classes; also track accuracy on classes 75-99.
_run_pa_phase(model, optimizer, range(40, 60), train_loader, test_loader, "./results/second_train",
              sub_loader=second_test_loader)

plt.figure(1)
plt.plot(train_t_pa, train_loss_pa, label='train loss: PA')
plt.plot(test_t_pa, test_loss_pa, label='test loss: PA')
plt.figure(2)
plt.plot(train_acc_t_pa, train_acc_pa, label='train accuracy: PA')
plt.plot(test_acc_t_pa, test_acc_pa, label='test accuracy: PA')
plt.plot(test_acc_t_pa, test_acc_sub_pa, label='later 25 class test accuracy: PA')

[torch.Size([48, 3, 3, 3]), torch.Size([48]), torch.Size([48]), torch.Size([48]), torch.Size([48, 48, 3, 3]), torch.Size([48]), torch.Size([48]), torch.Size([48]), torch.Size([48, 3, 1, 1]), torch.Size([48]), torch.Size([48]), torch.Size([48, 48, 3, 3]), torch.Size([48]), torch.Size([48]), torch.Size([48]), torch.Size([48, 48, 3, 3]), torch.Size([48]), torch.Size([48]), torch.Size([48]), torch.Size([96, 48, 3, 3]), torch.Size([96]), torch.Size([96]), torch.Size([96]), torch.Size([96, 96, 3, 3]), torch.Size([96]), torch.Size([96]), torch.Size([96]), torch.Size([96, 48, 1, 1]), torch.Size([96]), torch.Size([96]), torch.Size([96, 96, 3, 3]), torch.Size([96]), torch.Size([96]), torch.Size([96]), torch.Size([96, 96, 3, 3]), torch.Size([96]), torch.Size([96]), torch.Size([96]), torch.Size([192, 96, 3, 3]), torch.Size([192]), torch.Size([192]), torch.Size([192]), torch.Size([192, 192, 3, 3]), torch.Size([192]), torch.Size([192]), torch.Size([192]), torch.Size([192, 96, 1, 1]), torch.Size([192

100%|██████████| 586/586 [00:22<00:00, 25.55batch/s]
100%|██████████| 118/118 [00:01<00:00, 101.98batch/s]


1


 34%|███▍      | 199/586 [00:07<00:13, 28.09batch/s]


KeyboardInterrupt: 

In [None]:
#training and testing: cnn
# Baseline: the larger network trained on all 100 classes from the start.
model = PAResNet([2, 2, 2], [64, 128, 256], [512, 256, 100]).to(device)

st_time = time.time()
optimizer = Adam(model.parameters(), lr=1e-4)

# Histories. Per-batch: training loss. Per-epoch: accuracies and test loss.
# Every *_t_* list holds the elapsed seconds since st_time at which the value was recorded.
train_loss_cnn, train_t_cnn = [], []
train_acc_cnn, train_acc_t_cnn = [], []
test_loss_cnn, test_t_cnn = [], []
test_acc_cnn, test_acc_t_cnn = [], []
test_acc_sub_cnn = []

for epoch in range(50):
    print(epoch)
    model.train()
    ss = cnt = 0
    for batch in tqdm(train_loader, unit="batch", total=len(train_loader)):
        inputs = batch[0].to(device)
        labels = batch[1].to(device)
        optimizer.zero_grad()
        loss, acc, cor = calc_loss_and_accuracy(model(inputs), labels)
        loss.backward()
        optimizer.step()
        ss, cnt = ss + acc.item(), cnt + cor
        train_loss_cnn.append(loss.item())
        train_t_cnn.append(time.time() - st_time)
    train_acc_cnn.append(ss / cnt * 100)
    train_acc_t_cnn.append(time.time() - st_time)

    # Full test set: average loss and accuracy.
    hits, seen, loss_sum = test_(model, test_loader)
    test_loss_cnn.append(loss_sum / seen)
    test_t_cnn.append(time.time() - st_time)
    test_acc_cnn.append(hits / seen * 100)
    test_acc_t_cnn.append(time.time() - st_time)

    # Accuracy restricted to classes 75-99.
    sub_hits, sub_seen, _ = test_(model, second_test_loader)
    test_acc_sub_cnn.append(sub_hits / sub_seen * 100)
    torch.save(model.state_dict(), f"./results/whole_train/params{epoch}.pt")

plt.figure(1)
plt.plot(train_t_cnn, train_loss_cnn, label='train loss: CNN')
plt.plot(test_t_cnn, test_loss_cnn, label='test loss: CNN')
plt.figure(2)
plt.plot(train_acc_t_cnn, train_acc_cnn, label='train accuracy: CNN')
plt.plot(test_acc_t_cnn, test_acc_cnn, label='test accuracy: CNN')
plt.plot(test_acc_t_cnn, test_acc_sub_cnn, label='later 25 class test accuracy: CNN')

In [None]:
#show graphs
# With the notebook's inline backend, figures are closed at the end of the cell that drew
# them, so plt.figure(1) / plt.figure(2) here would be new, empty figures carrying only
# titles and an empty legend. Redraw every curve from the recorded histories instead
# (both training cells must have been run first).
loss_fig = plt.figure(1)
loss_fig.clf()  # start clean in case a backend kept the earlier drawing alive
loss_ax = loss_fig.add_subplot(111)
loss_ax.plot(train_t_pa, train_loss_pa, label='train loss: PA')
loss_ax.plot(test_t_pa, test_loss_pa, label='test loss: PA')
loss_ax.plot(train_t_cnn, train_loss_cnn, label='train loss: CNN')
loss_ax.plot(test_t_cnn, test_loss_cnn, label='test loss: CNN')
loss_ax.set_xlabel('Time (s)')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training Loss / Test Loss')
loss_ax.legend()

acc_fig = plt.figure(2)
acc_fig.clf()
acc_ax = acc_fig.add_subplot(111)
acc_ax.plot(train_acc_t_pa, train_acc_pa, label='train accuracy: PA')
acc_ax.plot(test_acc_t_pa, test_acc_pa, label='test accuracy: PA')
acc_ax.plot(test_acc_t_pa, test_acc_sub_pa, label='later 25 class test accuracy: PA')
acc_ax.plot(train_acc_t_cnn, train_acc_cnn, label='train accuracy: CNN')
acc_ax.plot(test_acc_t_cnn, test_acc_cnn, label='test accuracy: CNN')
acc_ax.plot(test_acc_t_cnn, test_acc_sub_cnn, label='later 25 class test accuracy: CNN')
acc_ax.set_xlabel('Time (s)')
acc_ax.set_ylabel('Accuracy (%)')
acc_ax.set_title('Training Accuracy / Test Accuracy')
acc_ax.legend()
plt.show()

In [None]:
# Define classes: 100 class names in alphabetical order.
# NOTE(review): presumably indexed by CIFAR-100 label id; verify against the dataset's own class list.
names = [
    "apple", "aquarium_fish", "baby", "bear", "beaver", "bed", "bee", "beetle", "bicycle", "bottle",
    "bowl", "boy", "bridge", "bus", "butterfly", "camel", "can", "castle", "caterpillar", "cattle",
    "chair", "chimpanzee", "clock", "cloud", "cockroach", "couch", "crab", "crocodile", "cup", "dinosaur",
    "dolphin", "elephant", "flatfish", "forest", "fox", "girl", "hamster", "house", "kangaroo", "keyboard",
    "lamp", "lawn_mower", "leopard", "lion", "lizard", "lobster", "man", "maple_tree", "motorcycle", "mountain",
    "mouse", "mushroom", "oak_tree", "orange", "orchid", "otter", "palm_tree", "pear", "pickup_truck", "pine_tree",
    "plain", "plate", "poppy", "porcupine", "possum", "rabbit", "raccoon", "ray", "road", "rocket",
    "rose", "sea", "seal", "shark", "shrew", "skunk", "skyscraper", "snail", "snake", "spider",
    "squirrel", "streetcar", "sunflower", "sweet_pepper", "table", "tank", "telephone", "television", "tiger", "tractor",
    "train", "trout", "tulip", "turtle", "wardrobe", "whale", "willow_tree", "wolf", "woman", "worm",
]

In [None]:
# Testing model: placeholder cell, no model is loaded yet.
model = None # TODO: load a trained model here
