In [2]:
!pip install einops

Collecting einops
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
     ---------------------------------------- 0.0/42.2 kB ? eta -:--:--
     ---------------------------------------- 42.2/42.2 kB 1.0 MB/s eta 0:00:00
Installing collected packages: einops
Successfully installed einops-0.6.1


In [2]:
import torch
import timm
import numpy as np

from einops import repeat, rearrange
from einops.layers.torch import Rearrange

from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import Block

In [3]:
def random_indexes(size : int):
    forward_indexes = np.arange(size)
    np.random.shuffle(forward_indexes)
    backward_indexes = np.argsort(forward_indexes)
    return forward_indexes, backward_indexes

random_indexes(5)

(array([3, 2, 4, 1, 0]), array([4, 3, 1, 0, 2], dtype=int64))

In [4]:
def take_indexes(sequences, indexes):
    return torch.gather(sequences, 0, repeat(indexes, 't b -> t b c', c=sequences.shape[-1]))

In [5]:
class PatchShuffle(torch.nn.Module):
    def __init__(self, ratio) -> None:
        super().__init__()
        self.ratio = ratio

    def forward(self, patches : torch.Tensor):
        T, B, C = patches.shape
        remain_T = int(T * (1 - self.ratio))

        indexes = [random_indexes(T) for _ in range(B)]
        forward_indexes = torch.as_tensor(np.stack([i[0] for i in indexes], axis=-1), dtype=torch.long).to(patches.device)
        backward_indexes = torch.as_tensor(np.stack([i[1] for i in indexes], axis=-1), dtype=torch.long).to(patches.device)

        patches = take_indexes(patches, forward_indexes)
        patches = patches[:remain_T]

        return patches, forward_indexes, backward_indexes

In [7]:
class MAE_Encoder(torch.nn.Module):
    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 num_layer=12,
                 num_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()

        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2, 1, emb_dim))
        self.shuffle = PatchShuffle(mask_ratio)

        self.patchify = torch.nn.Conv2d(3, emb_dim, patch_size, patch_size)

        self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])

        self.layer_norm = torch.nn.LayerNorm(emb_dim)

        self.init_weight()
        
    def init_weight(self):
        trunc_normal_(self.cls_token, std=.02)
        trunc_normal_(self.pos_embedding, std=.02)

    def forward(self, img):
        patches = self.patchify(img)
        patches = rearrange(patches, 'b c h w -> (h w) b c')
        patches = patches + self.pos_embedding

        patches, forward_indexes, backward_indexes = self.shuffle(patches)

        patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim=0)
        patches = rearrange(patches, 't b c -> b t c')
        features = self.layer_norm(self.transformer(patches))
        features = rearrange(features, 'b t c -> t b c')

        return features, backward_indexes

In [8]:
class MAE_Decoder(torch.nn.Module):
    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 num_layer=4,
                 num_head=3,
                 ) -> None:
        super().__init__()

        self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2 + 1, 1, emb_dim))

        self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])

        self.head = torch.nn.Linear(emb_dim, 3 * patch_size ** 2)
        self.patch2img = Rearrange('(h w) b (c p1 p2) -> b c (h p1) (w p2)', p1=patch_size, p2=patch_size, h=image_size//patch_size)

        self.init_weight()

    def init_weight(self):
        trunc_normal_(self.mask_token, std=.02)
        trunc_normal_(self.pos_embedding, std=.02)

    def forward(self, features, backward_indexes):
        T = features.shape[0]
        backward_indexes = torch.cat([torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1], dim=0)
        features = torch.cat([features, self.mask_token.expand(backward_indexes.shape[0] - features.shape[0], features.shape[1], -1)], dim=0)
        features = take_indexes(features, backward_indexes)
        features = features + self.pos_embedding

        features = rearrange(features, 't b c -> b t c')
        features = self.transformer(features)
        features = rearrange(features, 'b t c -> t b c')
        features = features[1:] # remove global feature

        patches = self.head(features)
        mask = torch.zeros_like(patches)
        mask[T:] = 1
        mask = take_indexes(mask, backward_indexes[1:] - 1)
        img = self.patch2img(patches)
        mask = self.patch2img(mask)

        return img, mask

In [9]:
class MAE_ViT(torch.nn.Module):
    def __init__(self,
                 image_size=32,
                 patch_size=2,
                 emb_dim=192,
                 encoder_layer=12,
                 encoder_head=3,
                 decoder_layer=4,
                 decoder_head=3,
                 mask_ratio=0.75,
                 ) -> None:
        super().__init__()

        self.encoder = MAE_Encoder(image_size, patch_size, emb_dim, encoder_layer, encoder_head, mask_ratio)
        self.decoder = MAE_Decoder(image_size, patch_size, emb_dim, decoder_layer, decoder_head)

    def forward(self, img):
        features, backward_indexes = self.encoder(img)
        predicted_img, mask = self.decoder(features,  backward_indexes)
        return predicted_img, mask

In [10]:
class ViT_Classifier(torch.nn.Module):
    def __init__(self, encoder : MAE_Encoder, num_classes=10) -> None:
        super().__init__()
        self.cls_token = encoder.cls_token
        self.pos_embedding = encoder.pos_embedding
        self.patchify = encoder.patchify
        self.transformer = encoder.transformer
        self.layer_norm = encoder.layer_norm
        self.head = torch.nn.Linear(self.pos_embedding.shape[-1], num_classes)

    def forward(self, img):
        patches = self.patchify(img)
        patches = rearrange(patches, 'b c h w -> (h w) b c')
        patches = patches + self.pos_embedding
        patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim=0)
        patches = rearrange(patches, 't b c -> b t c')
        features = self.layer_norm(self.transformer(patches))
        features = rearrange(features, 'b t c -> t b c')
        logits = self.head(features[0])
        return logits


In [11]:
shuffle = PatchShuffle(0.75)
a = torch.rand(16, 2, 10)
b, forward_indexes, backward_indexes = shuffle(a)
print(b.shape)

img = torch.rand(2, 3, 32, 32)
encoder = MAE_Encoder()
decoder = MAE_Decoder()
features, backward_indexes = encoder(img)
print(forward_indexes.shape)
predicted_img, mask = decoder(features, backward_indexes)
print(predicted_img.shape)
loss = torch.mean((predicted_img - img) ** 2 * mask / 0.75)
print(loss)

torch.Size([4, 2, 10])
torch.Size([16, 2])
torch.Size([2, 3, 32, 32])
tensor(0.4049, grad_fn=<MeanBackward0>)


In [12]:
import argparse
from tqdm import tqdm
import torchvision
from torchvision.transforms import ToTensor, Compose, Normalize
import math 
import cv2


parser = argparse.ArgumentParser()
seed = 42
batch_size = 256
max_device_batch_size = 512
base_learning_rate = 1.5e-4
weight_decay =0.05
mask_ratio =0.75
total_epoch = 10
warmup_epoch = 2
# model_path='vit-t-mae.pt'



batch_size = batch_size
load_batch_size = min(max_device_batch_size, batch_size)

assert batch_size % load_batch_size == 0
steps_per_update = batch_size // load_batch_size

train_dataset = torchvision.datasets.CIFAR10('data', train=True, download=True, transform=Compose([ToTensor(), Normalize(0.5, 0.5)]))
val_dataset = torchvision.datasets.CIFAR10('data', train=False, download=True, transform=Compose([ToTensor(), Normalize(0.5, 0.5)]))
dataloader = torch.utils.data.DataLoader(train_dataset, load_batch_size, shuffle=True, num_workers=4)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = MAE_ViT(mask_ratio=mask_ratio).to(device)
optim = torch.optim.AdamW(model.parameters(), lr=base_learning_rate * batch_size / 256, betas=(0.9, 0.95), weight_decay=weight_decay)
lr_func = lambda epoch: min((epoch + 1) / (warmup_epoch + 1e-8), 0.5 * (math.cos(epoch / total_epoch * math.pi) + 1))
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lr_func, verbose=True)

step_count = 0
optim.zero_grad()
for e in range(total_epoch):
    model.train()
    losses = []
    for img, label in tqdm(iter(dataloader)):
        step_count += 1
        img = img.to(device)
        predicted_img, mask = model(img)
        loss = torch.mean((predicted_img - img) ** 2 * mask) / mask_ratio
        loss.backward()
        if step_count % steps_per_update == 0:
            optim.step()
            optim.zero_grad()
        losses.append(loss.item())
    lr_scheduler.step()
    avg_loss = sum(losses) / len(losses)
    print(f'In epoch {e}, average traning loss is {avg_loss}.')

    ''' visualize the first 16 predicted images on val dataset'''
    if (e % 50 == 0) or (e == total_epoch - 1):
        model.eval()
        with torch.no_grad():
            val_img = torch.stack([val_dataset[i][0] for i in range(16)])
            val_img = val_img.to(device)
            predicted_val_img, mask = model(val_img)
            predicted_val_img = predicted_val_img * mask + val_img * (1 - mask)
            img = torch.cat([val_img * (1 - mask), predicted_val_img, val_img], dim=0)
            img = rearrange(img, '(v h1 w1) c h w -> c (h1 h) (w1 v w)', w1=2, v=3)
            cv2.imwrite('mae_image_' + f'{e}.jpg', (((img.permute(1, 2, 0) + 1)/2)*255).detach().cpu().numpy())
        
        ''' save model '''
        torch.save(model, model_path)

Files already downloaded and verified
Files already downloaded and verified
Adjusting learning rate of group 0 to 7.5000e-05.


  5%|▍         | 9/196 [8:37:32<71:31:00, 1376.79s/it]  

# Training the Classifier

In [12]:
import os
import argparse
import math
import torch
import torchvision
from torchvision.transforms import ToTensor, Compose, Normalize
from tqdm import tqdm


seed = 42
batch_size = 128
max_device_batch_size = 256
base_learning_rate = 1e-3
weight_decay = 0.05
total_epoch = 10
warmup_epoch = 2
pretrained_model_path = 'vit-t-mae.pt'
output_model_path ='vit-t-classifier-from_pretrained.pt'


batch_size = batch_size
load_batch_size = min(max_device_batch_size, batch_size)

assert batch_size % load_batch_size == 0
steps_per_update = batch_size // load_batch_size

train_dataset = torchvision.datasets.CIFAR10('data', train=True, download=True, transform=Compose([ToTensor(), Normalize(0.5, 0.5)]))
val_dataset = torchvision.datasets.CIFAR10('data', train=False, download=True, transform=Compose([ToTensor(), Normalize(0.5, 0.5)]))
train_dataloader = torch.utils.data.DataLoader(train_dataset, load_batch_size, shuffle=True, num_workers=4)
val_dataloader = torch.utils.data.DataLoader(val_dataset, load_batch_size, shuffle=False, num_workers=4)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

if pretrained_model_path is not None:
    model = torch.load(pretrained_model_path, map_location='cpu')
else:
    model = MAE_ViT()
model = ViT_Classifier(model.encoder, num_classes=10).to(device)



Files already downloaded and verified
Files already downloaded and verified


In [13]:
loss_fn = torch.nn.CrossEntropyLoss()
acc_fn = lambda logit, label: torch.mean((logit.argmax(dim=-1) == label).float())

optim = torch.optim.AdamW(model.parameters(), lr=base_learning_rate * batch_size / 256, betas=(0.9, 0.999), weight_decay=weight_decay)
lr_func = lambda epoch: min((epoch + 1) / (warmup_epoch + 1e-8), 0.5 * (math.cos(epoch / total_epoch * math.pi) + 1))
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lr_func, verbose=True)

Adjusting learning rate of group 0 to 1.0000e-04.


In [14]:


best_val_acc = 0
step_count = 0
optim.zero_grad()
for e in range(total_epoch):
    model.train()
    losses = []
    acces = []
    for img, label in tqdm(iter(train_dataloader)):
        step_count += 1
        img = img.to(device)
        label = label.to(device)
        logits = model(img)
        loss = loss_fn(logits, label)
        acc = acc_fn(logits, label)
        loss.backward()
        if step_count % steps_per_update == 0:
            optim.step()
            optim.zero_grad()
        losses.append(loss.item())
        acces.append(acc.item())
    lr_scheduler.step()
    avg_train_loss = sum(losses) / len(losses)
    avg_train_acc = sum(acces) / len(acces)
    print(f'In epoch {e}, average training loss is {avg_train_loss}, average training acc is {avg_train_acc}.')

    model.eval()
    with torch.no_grad():
        losses = []
        acces = []
        for img, label in tqdm(iter(val_dataloader)):
            img = img.to(device)
            label = label.to(device)
            logits = model(img)
            loss = loss_fn(logits, label)
            acc = acc_fn(logits, label)
            losses.append(loss.item())
            acces.append(acc.item())
        avg_val_loss = sum(losses) / len(losses)
        avg_val_acc = sum(acces) / len(acces)
        print(f'In epoch {e}, average validation loss is {avg_val_loss}, average validation acc is {avg_val_acc}.')  

    if avg_val_acc > best_val_acc:
        best_val_acc = avg_val_acc
        print(f'saving best model with acc {best_val_acc} at {e} epoch!')       
        torch.save(model, output_model_path)

100%|██████████| 391/391 [01:41<00:00,  3.87it/s]


Adjusting learning rate of group 0 to 2.0000e-04.
In epoch 0, average training loss is 0.8092588828805157, average training acc is 0.749932065186903.


100%|██████████| 79/79 [00:06<00:00, 12.38it/s]


In epoch 0, average validation loss is 0.5546492560754849, average validation acc is 0.8222903481012658.
saving best model with acc 0.8222903481012658 at 0 epoch!


100%|██████████| 391/391 [01:43<00:00,  3.79it/s]


Adjusting learning rate of group 0 to 3.0000e-04.
In epoch 1, average training loss is 0.4715146217733393, average training acc is 0.8454483695652174.


100%|██████████| 79/79 [00:06<00:00, 12.56it/s]


In epoch 1, average validation loss is 0.4984543546091152, average validation acc is 0.834256329113924.
saving best model with acc 0.834256329113924 at 1 epoch!


100%|██████████| 391/391 [01:45<00:00,  3.70it/s]


Adjusting learning rate of group 0 to 4.0000e-04.
In epoch 2, average training loss is 0.3823876475053065, average training acc is 0.8727701406954499.


100%|██████████| 79/79 [00:06<00:00, 12.46it/s]


In epoch 2, average validation loss is 0.4729167155072659, average validation acc is 0.840684335443038.
saving best model with acc 0.840684335443038 at 2 epoch!


100%|██████████| 391/391 [01:46<00:00,  3.68it/s]


Adjusting learning rate of group 0 to 4.9803e-04.
In epoch 3, average training loss is 0.33977085889300424, average training acc is 0.885246163743841.


100%|██████████| 79/79 [00:06<00:00, 12.54it/s]


In epoch 3, average validation loss is 0.48131733472588695, average validation acc is 0.8379153481012658.


100%|██████████| 391/391 [01:47<00:00,  3.63it/s]


Adjusting learning rate of group 0 to 4.9692e-04.
In epoch 4, average training loss is 0.30732533465261047, average training acc is 0.8955522697904835.


100%|██████████| 79/79 [00:06<00:00, 12.22it/s]


In epoch 4, average validation loss is 0.4718705702431594, average validation acc is 0.8449367088607594.
saving best model with acc 0.8449367088607594 at 4 epoch!


100%|██████████| 391/391 [01:47<00:00,  3.62it/s]


Adjusting learning rate of group 0 to 4.9557e-04.
In epoch 5, average training loss is 0.25423051270148944, average training acc is 0.914122442455243.


100%|██████████| 79/79 [00:06<00:00, 12.51it/s]


In epoch 5, average validation loss is 0.48160936372189583, average validation acc is 0.8443433544303798.


100%|██████████| 391/391 [01:47<00:00,  3.62it/s]


Adjusting learning rate of group 0 to 4.9398e-04.
In epoch 6, average training loss is 0.2108131613763397, average training acc is 0.9272858056875751.


100%|██████████| 79/79 [00:06<00:00, 12.35it/s]


In epoch 6, average validation loss is 0.47957642583907406, average validation acc is 0.8446400316455697.


100%|██████████| 391/391 [01:48<00:00,  3.60it/s]


Adjusting learning rate of group 0 to 4.9215e-04.
In epoch 7, average training loss is 0.18042440136985097, average training acc is 0.9387468030995421.


100%|██████████| 79/79 [00:06<00:00, 12.40it/s]


In epoch 7, average validation loss is 0.4901231825351715, average validation acc is 0.8485957278481012.
saving best model with acc 0.8485957278481012 at 7 epoch!


100%|██████████| 391/391 [01:48<00:00,  3.60it/s]


Adjusting learning rate of group 0 to 4.9007e-04.
In epoch 8, average training loss is 0.15164572181527877, average training acc is 0.9480498721227621.


100%|██████████| 79/79 [00:06<00:00, 12.53it/s]


In epoch 8, average validation loss is 0.4912558986416346, average validation acc is 0.8559137658227848.
saving best model with acc 0.8559137658227848 at 8 epoch!


100%|██████████| 391/391 [01:49<00:00,  3.57it/s]


Adjusting learning rate of group 0 to 4.8776e-04.
In epoch 9, average training loss is 0.13950247662927945, average training acc is 0.9512827686031761.


100%|██████████| 79/79 [00:06<00:00, 12.38it/s]


In epoch 9, average validation loss is 0.4892511794084235, average validation acc is 0.8582871835443038.
saving best model with acc 0.8582871835443038 at 9 epoch!


100%|██████████| 391/391 [01:48<00:00,  3.60it/s]


Adjusting learning rate of group 0 to 4.8522e-04.
In epoch 10, average training loss is 0.1161088297319839, average training acc is 0.9597106777188723.


100%|██████████| 79/79 [00:06<00:00, 12.38it/s]


In epoch 10, average validation loss is 0.5005949692258352, average validation acc is 0.857001582278481.


100%|██████████| 391/391 [01:48<00:00,  3.60it/s]


Adjusting learning rate of group 0 to 4.8244e-04.
In epoch 11, average training loss is 0.11460668337352746, average training acc is 0.9612372122762148.


100%|██████████| 79/79 [00:06<00:00, 12.74it/s]


In epoch 11, average validation loss is 0.5008450482465043, average validation acc is 0.8605617088607594.
saving best model with acc 0.8605617088607594 at 11 epoch!


100%|██████████| 391/391 [01:48<00:00,  3.59it/s]


Adjusting learning rate of group 0 to 4.7944e-04.
In epoch 12, average training loss is 0.10881925941637867, average training acc is 0.9621443414627133.


100%|██████████| 79/79 [00:06<00:00, 12.69it/s]


In epoch 12, average validation loss is 0.503437374966054, average validation acc is 0.8592761075949367.


100%|██████████| 391/391 [01:49<00:00,  3.58it/s]


Adjusting learning rate of group 0 to 4.7621e-04.
In epoch 13, average training loss is 0.09473688315952677, average training acc is 0.9675231777188723.


100%|██████████| 79/79 [00:06<00:00, 12.51it/s]


In epoch 13, average validation loss is 0.5620209970806218, average validation acc is 0.8441455696202531.


100%|██████████| 391/391 [01:49<00:00,  3.57it/s]


Adjusting learning rate of group 0 to 4.7275e-04.
In epoch 14, average training loss is 0.08537368332762318, average training acc is 0.9709239131044549.


100%|██████████| 79/79 [00:06<00:00, 12.37it/s]


In epoch 14, average validation loss is 0.5909798284874687, average validation acc is 0.8423655063291139.


100%|██████████| 391/391 [01:49<00:00,  3.57it/s]


Adjusting learning rate of group 0 to 4.6908e-04.
In epoch 15, average training loss is 0.08503256307538513, average training acc is 0.9707800511204069.


100%|██████████| 79/79 [00:06<00:00, 12.32it/s]


In epoch 15, average validation loss is 0.521845360156856, average validation acc is 0.8626384493670886.
saving best model with acc 0.8626384493670886 at 15 epoch!


100%|██████████| 391/391 [01:50<00:00,  3.54it/s]


Adjusting learning rate of group 0 to 4.6519e-04.
In epoch 16, average training loss is 0.08151754396526939, average training acc is 0.9715752877542735.


100%|██████████| 79/79 [00:06<00:00, 12.38it/s]


In epoch 16, average validation loss is 0.5232162524627734, average validation acc is 0.8558148734177216.


100%|██████████| 391/391 [01:49<00:00,  3.56it/s]


Adjusting learning rate of group 0 to 4.6108e-04.
In epoch 17, average training loss is 0.07273908439890274, average training acc is 0.9755155050846012.


100%|██████████| 79/79 [00:06<00:00, 12.59it/s]


In epoch 17, average validation loss is 0.5645346075673646, average validation acc is 0.8512658227848101.


100%|██████████| 391/391 [01:48<00:00,  3.59it/s]


Adjusting learning rate of group 0 to 4.5677e-04.
In epoch 18, average training loss is 0.07176914598668932, average training acc is 0.9756473786080889.


100%|██████████| 79/79 [00:06<00:00, 12.57it/s]


In epoch 18, average validation loss is 0.5463785500843313, average validation acc is 0.861748417721519.


100%|██████████| 391/391 [01:48<00:00,  3.60it/s]


Adjusting learning rate of group 0 to 4.5225e-04.
In epoch 19, average training loss is 0.0643666934977045, average training acc is 0.9779171995494677.


100%|██████████| 79/79 [00:06<00:00, 12.34it/s]


In epoch 19, average validation loss is 0.5443740796439255, average validation acc is 0.8612539556962026.


100%|██████████| 391/391 [01:48<00:00,  3.60it/s]


Adjusting learning rate of group 0 to 4.4754e-04.
In epoch 20, average training loss is 0.06376871205342319, average training acc is 0.9776414641943734.


100%|██████████| 79/79 [00:06<00:00, 12.76it/s]


In epoch 20, average validation loss is 0.6261422279514844, average validation acc is 0.8467167721518988.


100%|██████████| 391/391 [01:49<00:00,  3.57it/s]


Adjusting learning rate of group 0 to 4.4263e-04.
In epoch 21, average training loss is 0.06677096930172895, average training acc is 0.9767862852577054.


100%|██████████| 79/79 [00:06<00:00, 12.96it/s]


In epoch 21, average validation loss is 0.584846763671199, average validation acc is 0.8527492088607594.


100%|██████████| 391/391 [01:46<00:00,  3.67it/s]


Adjusting learning rate of group 0 to 4.3753e-04.
In epoch 22, average training loss is 0.05644998595456752, average training acc is 0.9811940537694165.


100%|██████████| 79/79 [00:06<00:00, 13.05it/s]


In epoch 22, average validation loss is 0.5825380577316767, average validation acc is 0.8476068037974683.


100%|██████████| 391/391 [01:47<00:00,  3.65it/s]


Adjusting learning rate of group 0 to 4.3224e-04.
In epoch 23, average training loss is 0.05097952040503054, average training acc is 0.9818853900560638.


100%|██████████| 79/79 [00:06<00:00, 12.90it/s]


In epoch 23, average validation loss is 0.5979391211950327, average validation acc is 0.8542325949367089.


100%|██████████| 391/391 [01:45<00:00,  3.70it/s]


Adjusting learning rate of group 0 to 4.2678e-04.
In epoch 24, average training loss is 0.05708466205612549, average training acc is 0.9801310741383097.


100%|██████████| 79/79 [00:06<00:00, 12.82it/s]


In epoch 24, average validation loss is 0.6013690927360631, average validation acc is 0.8553204113924051.


100%|██████████| 391/391 [01:45<00:00,  3.71it/s]


Adjusting learning rate of group 0 to 4.2114e-04.
In epoch 25, average training loss is 0.053713259125804845, average training acc is 0.9818214514981145.


100%|██████████| 79/79 [00:06<00:00, 12.91it/s]


In epoch 25, average validation loss is 0.5990769021873232, average validation acc is 0.8549248417721519.


100%|██████████| 391/391 [01:45<00:00,  3.70it/s]


Adjusting learning rate of group 0 to 4.1533e-04.
In epoch 26, average training loss is 0.048606577087574834, average training acc is 0.9838515026185214.


100%|██████████| 79/79 [00:06<00:00, 12.82it/s]


In epoch 26, average validation loss is 0.5582658720167377, average validation acc is 0.8620450949367089.


100%|██████████| 391/391 [01:45<00:00,  3.70it/s]


Adjusting learning rate of group 0 to 4.0936e-04.
In epoch 27, average training loss is 0.051527983944891664, average training acc is 0.9821771099744245.


100%|██████████| 79/79 [00:06<00:00, 12.93it/s]


In epoch 27, average validation loss is 0.5628088478800617, average validation acc is 0.8637262658227848.
saving best model with acc 0.8637262658227848 at 27 epoch!


100%|██████████| 391/391 [01:45<00:00,  3.70it/s]


Adjusting learning rate of group 0 to 4.0323e-04.
In epoch 28, average training loss is 0.046747243334956064, average training acc is 0.9842750959079284.


100%|██████████| 79/79 [00:06<00:00, 13.00it/s]


In epoch 28, average validation loss is 0.5547978240477888, average validation acc is 0.861748417721519.


100%|██████████| 391/391 [01:46<00:00,  3.66it/s]


Adjusting learning rate of group 0 to 3.9695e-04.
In epoch 29, average training loss is 0.04064882244757565, average training acc is 0.9859015345573425.


100%|██████████| 79/79 [00:06<00:00, 12.82it/s]


In epoch 29, average validation loss is 0.6186274806909924, average validation acc is 0.8601661392405063.


100%|██████████| 391/391 [01:46<00:00,  3.67it/s]


Adjusting learning rate of group 0 to 3.9052e-04.
In epoch 30, average training loss is 0.0400939331475712, average training acc is 0.9857696611862963.


100%|██████████| 79/79 [00:06<00:00, 12.64it/s]


In epoch 30, average validation loss is 0.6476732284962377, average validation acc is 0.8517602848101266.


100%|██████████| 391/391 [01:46<00:00,  3.68it/s]


Adjusting learning rate of group 0 to 3.8396e-04.
In epoch 31, average training loss is 0.040834049848468064, average training acc is 0.9862731777493606.


100%|██████████| 79/79 [00:06<00:00, 12.64it/s]


In epoch 31, average validation loss is 0.5751183393258082, average validation acc is 0.8644185126582279.
saving best model with acc 0.8644185126582279 at 31 epoch!


100%|██████████| 391/391 [01:46<00:00,  3.68it/s]


Adjusting learning rate of group 0 to 3.7726e-04.
In epoch 32, average training loss is 0.04018619831989679, average training acc is 0.9860174233651222.


100%|██████████| 79/79 [00:06<00:00, 12.70it/s]


In epoch 32, average validation loss is 0.5866644812535636, average validation acc is 0.8628362341772152.


100%|██████████| 391/391 [01:46<00:00,  3.69it/s]


Adjusting learning rate of group 0 to 3.7044e-04.
In epoch 33, average training loss is 0.03711950368411086, average training acc is 0.9868326406649617.


100%|██████████| 79/79 [00:06<00:00, 12.54it/s]


In epoch 33, average validation loss is 0.5789365078075023, average validation acc is 0.8631329113924051.


100%|██████████| 391/391 [01:46<00:00,  3.68it/s]


Adjusting learning rate of group 0 to 3.6350e-04.
In epoch 34, average training loss is 0.03259238757459861, average training acc is 0.9890864770430738.


100%|██████████| 79/79 [00:06<00:00, 12.82it/s]


In epoch 34, average validation loss is 0.5854387224852284, average validation acc is 0.8663963607594937.
saving best model with acc 0.8663963607594937 at 34 epoch!


100%|██████████| 391/391 [01:46<00:00,  3.68it/s]


Adjusting learning rate of group 0 to 3.5644e-04.
In epoch 35, average training loss is 0.030739561260125274, average training acc is 0.9895580243271636.


100%|██████████| 79/79 [00:06<00:00, 12.76it/s]


In epoch 35, average validation loss is 0.5823520010785211, average validation acc is 0.8697587025316456.
saving best model with acc 0.8697587025316456 at 35 epoch!


100%|██████████| 391/391 [01:48<00:00,  3.59it/s]


Adjusting learning rate of group 0 to 3.4929e-04.
In epoch 36, average training loss is 0.03310428506246937, average training acc is 0.9886468990379588.


100%|██████████| 79/79 [00:06<00:00, 12.85it/s]


In epoch 36, average validation loss is 0.6676604329030725, average validation acc is 0.8506724683544303.


100%|██████████| 391/391 [01:47<00:00,  3.64it/s]


Adjusting learning rate of group 0 to 3.4203e-04.
In epoch 37, average training loss is 0.03147744701441635, average training acc is 0.9891783887772914.


100%|██████████| 79/79 [00:06<00:00, 12.66it/s]


In epoch 37, average validation loss is 0.5992068904864637, average validation acc is 0.8674841772151899.


100%|██████████| 391/391 [01:45<00:00,  3.69it/s]


Adjusting learning rate of group 0 to 3.3468e-04.
In epoch 38, average training loss is 0.026480471644469577, average training acc is 0.990948689258312.


100%|██████████| 79/79 [00:06<00:00, 12.82it/s]


In epoch 38, average validation loss is 0.579425045206577, average validation acc is 0.8694620253164557.


100%|██████████| 391/391 [01:46<00:00,  3.69it/s]


Adjusting learning rate of group 0 to 3.2725e-04.
In epoch 39, average training loss is 0.026454390139471443, average training acc is 0.9911764706187236.


100%|██████████| 79/79 [00:06<00:00, 12.63it/s]


In epoch 39, average validation loss is 0.5703269058390509, average validation acc is 0.8725276898734177.
saving best model with acc 0.8725276898734177 at 39 epoch!


100%|██████████| 391/391 [01:45<00:00,  3.69it/s]


Adjusting learning rate of group 0 to 3.1975e-04.
In epoch 40, average training loss is 0.028274875346814162, average training acc is 0.9902373721532505.


100%|██████████| 79/79 [00:06<00:00, 12.83it/s]


In epoch 40, average validation loss is 0.5923376984988586, average validation acc is 0.8668908227848101.


100%|██████████| 391/391 [01:45<00:00,  3.69it/s]


Adjusting learning rate of group 0 to 3.1217e-04.
In epoch 41, average training loss is 0.02531788310677985, average training acc is 0.9916080562659847.


100%|██████████| 79/79 [00:06<00:00, 12.96it/s]


In epoch 41, average validation loss is 0.6034003292457967, average validation acc is 0.8685719936708861.


100%|██████████| 391/391 [01:46<00:00,  3.67it/s]


Adjusting learning rate of group 0 to 3.0454e-04.
In epoch 42, average training loss is 0.02403855361633927, average training acc is 0.9919956841737109.


100%|██████████| 79/79 [00:06<00:00, 12.76it/s]


In epoch 42, average validation loss is 0.6420080673091019, average validation acc is 0.8614517405063291.


100%|██████████| 391/391 [01:45<00:00,  3.71it/s]


Adjusting learning rate of group 0 to 2.9685e-04.
In epoch 43, average training loss is 0.022164296506650392, average training acc is 0.992746962915601.


100%|██████████| 79/79 [00:06<00:00, 13.07it/s]


In epoch 43, average validation loss is 0.6420605209054826, average validation acc is 0.8658030063291139.


100%|██████████| 391/391 [01:45<00:00,  3.70it/s]


Adjusting learning rate of group 0 to 2.8911e-04.
In epoch 44, average training loss is 0.02092304814126357, average training acc is 0.9930466751918159.


100%|██████████| 79/79 [00:06<00:00, 12.81it/s]


In epoch 44, average validation loss is 0.5869935075693493, average validation acc is 0.8750988924050633.
saving best model with acc 0.8750988924050633 at 44 epoch!


100%|██████████| 391/391 [01:45<00:00,  3.70it/s]


Adjusting learning rate of group 0 to 2.8133e-04.
In epoch 45, average training loss is 0.01815817006738604, average training acc is 0.993985773657289.


100%|██████████| 79/79 [00:06<00:00, 12.80it/s]


In epoch 45, average validation loss is 0.6135237550810922, average validation acc is 0.8661985759493671.


100%|██████████| 391/391 [01:46<00:00,  3.68it/s]


Adjusting learning rate of group 0 to 2.7353e-04.
In epoch 46, average training loss is 0.020812634517566143, average training acc is 0.992746962915601.


100%|██████████| 79/79 [00:06<00:00, 12.89it/s]


In epoch 46, average validation loss is 0.6670222886001007, average validation acc is 0.8595727848101266.


100%|██████████| 391/391 [01:46<00:00,  3.68it/s]


Adjusting learning rate of group 0 to 2.6570e-04.
In epoch 47, average training loss is 0.01852897750100602, average training acc is 0.9939138427414858.


100%|██████████| 79/79 [00:06<00:00, 12.84it/s]


In epoch 47, average validation loss is 0.6099797973904428, average validation acc is 0.8643196202531646.


 59%|█████▉    | 231/391 [01:08<00:47,  3.36it/s]


RuntimeError: CUDA error: unknown error
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.