In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
from tensorboardX import SummaryWriter
from torch.optim.lr_scheduler import LambdaLR, StepLR

In [17]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [18]:
class Config:

    batch_size = 64
    epochs = 300
    lr = 1e-3

    channel = 3

    height = 32
    width = 32

    data_root = '../dataset/cifar10'

    dropout_rate = 0.1
    attn_dropout = 0

    patch_size = 4
    num_patches = int((height * width) / (patch_size ** 2))

    layers = 12
    embedding_d = 768
    mlp_size = 1024
    heads = 8

    num_classes = 10

    log_f = 100

In [19]:
transform = transforms.Compose([
  transforms.Resize(224),
  transforms.CenterCrop(224),
  transforms.ToTensor(),
  transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

test_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor()
])

In [20]:
trainset = torchvision.datasets.CIFAR10(root=Config.data_root, train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=Config.batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root=Config.data_root, train=False,
                                       download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=Config.batch_size,
                                         shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [22]:
!pip install timm





In [23]:
import timm

model = timm.create_model("vit_base_patch16_224", pretrained=True)
model.head = nn.Linear(model.head.in_features, 10)

In [24]:
# model = Model().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=Config.lr)
criterion = nn.CrossEntropyLoss()

In [25]:
scheduler = StepLR(optimizer, 5)

In [26]:
writer = SummaryWriter()

In [27]:
def test(epoch):

    model.eval()

    test_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():

        for batch_idx, (inputs, targets) in enumerate(tqdm(testloader)):

            if Config.batch_size != len(inputs):
                inputs = inputs.repeat(Config.batch_size // inputs.size(0) + 1, 1, 1, 1)[:Config.batch_size]
                targets = targets.repeat(Config.batch_size // targets.size(0) + 1)[:Config.batch_size]

            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    acc = 100. * correct / total

    print(f'Epoch {epoch} val loss: {test_loss:.5f}, test acc: {(acc):.5f}')

    return test_loss, acc

In [28]:
def train(epoch):

    model.train()

    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(tqdm(trainloader)):

        if Config.batch_size != len(inputs):
            inputs = inputs.repeat(Config.batch_size // inputs.size(0) + 1, 1, 1, 1)[:Config.batch_size]
            targets = targets.repeat(Config.batch_size // targets.size(0) + 1)[:Config.batch_size]

        inputs, targets = inputs.to(device), targets.to(device)

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        acc = 100.*correct/total

        if batch_idx % Config.log_f == 0:
            print(f'Epoch {epoch}, batch index {batch_idx} || train loss: {train_loss/(batch_idx+1)}, train acc: {acc}')

    return train_loss/(batch_idx+1), acc

In [29]:
model = model.to(device)

In [30]:
for epoch in range(Config.epochs):

    train_loss, train_acc = train(epoch)
    test_loss, test_acc = test(epoch)

    scheduler.step()

    writer.add_scalar('train/loss', train_loss, epoch)
    writer.add_scalar('train/acc', train_acc, epoch)
    writer.add_scalar('test/loss', test_loss, epoch)
    writer.add_scalar('test/acc', test_acc, epoch)

  0%|          | 0/782 [00:00<?, ?it/s]

Epoch 0, batch index 0 || train loss: 2.1957013607025146, train acc: 25.0
Epoch 0, batch index 100 || train loss: 2.356283929088328, train acc: 19.198638613861387
Epoch 0, batch index 200 || train loss: 2.1450099440949475, train acc: 23.34421641791045
Epoch 0, batch index 300 || train loss: 2.0265084465476764, train acc: 26.28737541528239
Epoch 0, batch index 400 || train loss: 1.9504247966252657, train acc: 28.56920199501247
Epoch 0, batch index 500 || train loss: 1.8881200804205949, train acc: 30.713572854291417
Epoch 0, batch index 600 || train loss: 1.8372360134680141, train acc: 32.43032445923461
Epoch 0, batch index 700 || train loss: 1.7999802239441158, train acc: 33.73083095577746


  0%|          | 0/157 [00:00<?, ?it/s]

Epoch 0 val loss: 451.13953, test acc: 16.47094


  0%|          | 0/782 [00:00<?, ?it/s]

Epoch 1, batch index 0 || train loss: 1.4259912967681885, train acc: 45.3125
Epoch 1, batch index 100 || train loss: 1.5125513725941724, train acc: 44.662747524752476
Epoch 1, batch index 200 || train loss: 1.5016722115711194, train acc: 45.01710199004975
Epoch 1, batch index 300 || train loss: 1.4901277369439008, train acc: 45.34364617940199
Epoch 1, batch index 400 || train loss: 1.4907569222319452, train acc: 45.417705735660846
Epoch 1, batch index 500 || train loss: 1.4952077879877148, train acc: 45.2064620758483
Epoch 1, batch index 600 || train loss: 1.5068747864388388, train acc: 44.76393510815308
Epoch 1, batch index 700 || train loss: 1.5045953584295537, train acc: 44.76863409415121


  0%|          | 0/157 [00:00<?, ?it/s]

Epoch 1 val loss: 406.78453, test acc: 23.14889


  0%|          | 0/782 [00:00<?, ?it/s]

Epoch 2, batch index 0 || train loss: 1.3873785734176636, train acc: 50.0
Epoch 2, batch index 100 || train loss: 1.5145045754933122, train acc: 44.04393564356435
Epoch 2, batch index 200 || train loss: 1.512555035785656, train acc: 44.16977611940298
Epoch 2, batch index 300 || train loss: 1.5104061428494628, train acc: 44.49231727574751
Epoch 2, batch index 400 || train loss: 1.5078077854361023, train acc: 44.56047381546135
Epoch 2, batch index 500 || train loss: 1.5121596768468677, train acc: 44.45172155688623
Epoch 2, batch index 600 || train loss: 1.514761192231329, train acc: 44.30896422628952
Epoch 2, batch index 700 || train loss: 1.5161997964140692, train acc: 44.25151569186876


  0%|          | 0/157 [00:00<?, ?it/s]

Epoch 2 val loss: 491.79149, test acc: 14.60987


  0%|          | 0/782 [00:00<?, ?it/s]

Epoch 3, batch index 0 || train loss: 1.6806981563568115, train acc: 43.75
Epoch 3, batch index 100 || train loss: 1.540088154301785, train acc: 43.48700495049505
Epoch 3, batch index 200 || train loss: 1.5370163223636684, train acc: 43.27580845771144
Epoch 3, batch index 300 || train loss: 1.5302226392929736, train acc: 43.94206810631229
Epoch 3, batch index 400 || train loss: 1.5253061178021894, train acc: 44.155236907730675
Epoch 3, batch index 500 || train loss: 1.5251142650307297, train acc: 43.993263473053894
Epoch 3, batch index 600 || train loss: 1.5213022777522462, train acc: 44.32976289517471
Epoch 3, batch index 700 || train loss: 1.5184095762935752, train acc: 44.583630527817405


  0%|          | 0/157 [00:00<?, ?it/s]

Epoch 3 val loss: 493.75098, test acc: 14.63973


  0%|          | 0/782 [00:00<?, ?it/s]

Epoch 4, batch index 0 || train loss: 1.6345605850219727, train acc: 37.5
Epoch 4, batch index 100 || train loss: 1.5451524470112112, train acc: 44.198638613861384
Epoch 4, batch index 200 || train loss: 1.5348767030298414, train acc: 44.37189054726368
Epoch 4, batch index 300 || train loss: 1.525741913389922, train acc: 44.48193521594684
Epoch 4, batch index 400 || train loss: 1.5190741421278575, train acc: 44.720230673316706
Epoch 4, batch index 500 || train loss: 1.5159407064110457, train acc: 44.69498502994012
Epoch 4, batch index 600 || train loss: 1.5140990754728905, train acc: 44.80553244592346
Epoch 4, batch index 700 || train loss: 1.5146451844978606, train acc: 44.73965763195435


  0%|          | 0/157 [00:00<?, ?it/s]

Epoch 4 val loss: 454.75834, test acc: 17.31688


  0%|          | 0/782 [00:00<?, ?it/s]

Epoch 5, batch index 0 || train loss: 1.8784370422363281, train acc: 39.0625
Epoch 5, batch index 100 || train loss: 1.4212833843608894, train acc: 48.39108910891089
Epoch 5, batch index 200 || train loss: 1.3988735242862607, train acc: 49.214863184079604
Epoch 5, batch index 300 || train loss: 1.3853280940325157, train acc: 49.64181893687708
Epoch 5, batch index 400 || train loss: 1.3788451110930218, train acc: 49.797381546134666
Epoch 5, batch index 500 || train loss: 1.374570086092768, train acc: 49.90955588822355
Epoch 5, batch index 600 || train loss: 1.3625050406289378, train acc: 50.34057820299501
Epoch 5, batch index 700 || train loss: 1.35530711206662, train acc: 50.68429029957204


  0%|          | 0/157 [00:00<?, ?it/s]

Epoch 5 val loss: 481.69023, test acc: 19.44666


  0%|          | 0/782 [00:00<?, ?it/s]

Epoch 6, batch index 0 || train loss: 1.3238550424575806, train acc: 45.3125
Epoch 6, batch index 100 || train loss: 1.2906110021147397, train acc: 53.620049504950494
Epoch 6, batch index 200 || train loss: 1.2779726991012914, train acc: 54.24440298507463
Epoch 6, batch index 300 || train loss: 1.2786178913623392, train acc: 53.86731727574751
Epoch 6, batch index 400 || train loss: 1.2780176036970277, train acc: 53.88092269326683
Epoch 6, batch index 500 || train loss: 1.2758706510661844, train acc: 53.982659680638726
Epoch 6, batch index 600 || train loss: 1.2745647957043322, train acc: 53.89715058236273
Epoch 6, batch index 700 || train loss: 1.2719453185499139, train acc: 53.92965406562054


  0%|          | 0/157 [00:00<?, ?it/s]

Epoch 6 val loss: 480.64025, test acc: 20.09355


  0%|          | 0/782 [00:00<?, ?it/s]

Epoch 7, batch index 0 || train loss: 1.2916309833526611, train acc: 54.6875
Epoch 7, batch index 100 || train loss: 1.2101096363350896, train acc: 55.89418316831683
Epoch 7, batch index 200 || train loss: 1.2277047224898836, train acc: 55.30939054726368
Epoch 7, batch index 300 || train loss: 1.2289888231065187, train acc: 55.185838870431894
Epoch 7, batch index 400 || train loss: 1.2271409303767427, train acc: 55.43952618453865
Epoch 7, batch index 500 || train loss: 1.2311303694091158, train acc: 55.37986526946108
Epoch 7, batch index 600 || train loss: 1.2286573118854085, train acc: 55.37125623960067
Epoch 7, batch index 700 || train loss: 1.225120975610703, train acc: 55.47432239657632


  0%|          | 0/157 [00:00<?, ?it/s]

Epoch 7 val loss: 473.28336, test acc: 20.33240


  0%|          | 0/782 [00:00<?, ?it/s]

Epoch 8, batch index 0 || train loss: 1.4359169006347656, train acc: 48.4375
Epoch 8, batch index 100 || train loss: 1.1746225752452812, train acc: 57.162747524752476
Epoch 8, batch index 200 || train loss: 1.1773303587045243, train acc: 56.94185323383085
Epoch 8, batch index 300 || train loss: 1.1798384225249687, train acc: 57.09613787375415
Epoch 8, batch index 400 || train loss: 1.1779758766702286, train acc: 57.27088528678304
Epoch 8, batch index 500 || train loss: 1.1788081211482218, train acc: 57.29478542914172
Epoch 8, batch index 600 || train loss: 1.178576178637995, train acc: 57.43032445923461
Epoch 8, batch index 700 || train loss: 1.1788171559529705, train acc: 57.34218972895863


  0%|          | 0/157 [00:00<?, ?it/s]

Epoch 8 val loss: 481.47501, test acc: 21.13854


  0%|          | 0/782 [00:00<?, ?it/s]

Epoch 9, batch index 0 || train loss: 1.04513418674469, train acc: 59.375
Epoch 9, batch index 100 || train loss: 1.1457961148554736, train acc: 58.18378712871287
Epoch 9, batch index 200 || train loss: 1.1479916121829208, train acc: 58.06125621890547
Epoch 9, batch index 300 || train loss: 1.151470702747966, train acc: 57.91112956810631
Epoch 9, batch index 400 || train loss: 1.1589559891871979, train acc: 57.598192019950126
Epoch 9, batch index 500 || train loss: 1.1590094020266732, train acc: 57.72205588822355
Epoch 9, batch index 600 || train loss: 1.155465188121637, train acc: 57.91649334442596
Epoch 9, batch index 700 || train loss: 1.1518560116539327, train acc: 58.17804921540656


  0%|          | 0/157 [00:00<?, ?it/s]

Epoch 9 val loss: 465.68100, test acc: 21.93471


  0%|          | 0/782 [00:00<?, ?it/s]

Epoch 10, batch index 0 || train loss: 1.0078420639038086, train acc: 60.9375
Epoch 10, batch index 100 || train loss: 1.0800429759639325, train acc: 60.9375


KeyboardInterrupt: 