In [None]:
# Colab only: mount Google Drive at /content/drive so the data and output paths used below exist.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os
import sys
import copy
import numpy as np
import random
import multiprocessing
from tqdm import tqdm

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [31]:
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

In [None]:
import torchvision
import torchvision.transforms as transforms

## Image Data

In [None]:
# Directory that holds the training batch files and the validation file.
# Absolute (same root as the Drive mount point) so that loading does not
# depend on the notebook's current working directory.
data_path = '/content/drive/MyDrive/Model/Data/ImageNet/'

In [None]:
def read_data(dir_path, size) :
    """Load the first `size` training batch files and stack them into arrays.

    Files are read from `<dir_path>/train_data_batch_<i>` for i = 1..size. Each
    file must unpickle to a mapping with a 'data' array and a 'labels' sequence.

    Args:
        dir_path: directory containing the batch files (trailing separator optional).
        size: number of batch files to read; must be at least 1.

    Returns:
        (img_data, label_data): images reshaped to (N, 3, 64, 64) and the labels
        as a 1-D array of length N.

    Raises:
        ValueError: if `size` is smaller than 1.
    """
    if size < 1 :
        raise ValueError('size must be at least 1, got {}'.format(size))

    img_data = []
    label_data = []
    for i in tqdm(range(1, size+1)) :
        # os.path.join works whether or not dir_path ends with a separator
        file_path = os.path.join(dir_path, 'train_data_batch_' + str(i))
        # SECURITY: allow_pickle=True can run arbitrary code while loading;
        # only point this at files from a trusted source.
        data_dict = np.load(file_path, allow_pickle=True)

        # NOTE(review): this treats every flat row as pixel-interleaved (H, W, C).
        # If the files use a channel-planar layout (all R, then G, then B), the
        # correct reshape would be (-1, 3, 64, 64) with no transpose - TODO confirm.
        img_idx = data_dict['data'].reshape(-1, 64, 64, 3)
        img_idx = np.transpose(img_idx , (0, 3, 1, 2))

        img_data.append(img_idx)
        label_data.extend(data_dict['labels'])

    return np.concatenate(img_data, axis=0), np.array(label_data)

In [None]:
# Read only the first 3 training batch files.
train_image, train_label = read_data(data_path, 3)

100%|██████████| 3/3 [01:04<00:00, 21.53s/it]


In [None]:
# Sanity check: images should be (N, 3, 64, 64) and labels (N,).
print('Train Data Shape \n') 

print('Image Shape : {}'.format(train_image.shape))
print('Label Shape : {}'.format(train_label.shape))

Train Data Shape 

Image Shape : (384348, 3, 64, 64)
Label Shape : (384348,)


In [None]:
# The validation split is a single pickled file in the same directory.
# SECURITY: allow_pickle=True can run arbitrary code while loading; only use trusted files.
val_data = np.load(data_path+'val_data' , allow_pickle=True)

In [None]:
# Reshape the flat validation pixels to (N, 64, 64, 3), then move channels first -> (N, 3, 64, 64).
# NOTE(review): this assumes each flat row is pixel-interleaved (H, W, C). If the file uses a
# channel-planar layout (all R, then G, then B), reshape(-1, 3, 64, 64) without the transpose
# would be needed instead - TODO confirm against the data source.
val_image = val_data['data'].reshape(-1, 64, 64, 3)
val_image = np.transpose(val_image, (0, 3, 1, 2))
val_label = np.array(val_data['labels'])

In [None]:
# Sanity check: images should be (N, 3, 64, 64) and labels (N,).
print('Validation Data Shape \n') 

print('Image Shape : {}'.format(val_image.shape))
print('Label Shape : {}'.format(val_label.shape))

Validation Data Shape 

Image Shape : (50000, 3, 64, 64)
Label Shape : (50000,)


## Dataset & Dataloader

In [13]:
class ImageDataset(Dataset) :
    """Map-style dataset that pairs image arrays with one-hot label vectors.

    Labels are 1-based class ids. They are kept as zero-based integer indices
    and expanded to a one-hot vector only when an item is requested, so memory
    use no longer grows with `len(label) * class_size`.

    Args:
        data: array of images, indexed along axis 0.
        label: sequence of integer class ids in the range 1..class_size.
        class_size: number of classes (length of each one-hot vector).

    Raises:
        ValueError: if any label lies outside 1..class_size (previously such
            values were silently wrapped or failed with an IndexError).
    """
    def __init__(self , data , label, class_size) :
        super(ImageDataset , self).__init__()
        label = np.asarray(label)
        if label.size and (label.min() < 1 or label.max() > class_size) :
            raise ValueError('labels must lie in 1..{}, got range {}..{}'.format(
                class_size, label.min(), label.max()))
        self.data = data
        self.class_size = class_size
        # zero-based class indices, one per sample
        self.label = label - 1

    def __len__(self) :
        return self.data.shape[0]

    def __getitem__(self , idx) :
        # float64 vector, the same dtype np.eye() produced before
        img_label = np.zeros(self.class_size)
        img_label[self.label[idx]] = 1.0
        return self.data[idx] , img_label

class TrainTransforms :
    """Training-time augmentation: resize, random horizontal flip, random crop.

    Args:
        org_size: size passed to `transforms.Resize` before cropping.
        tar_size: size passed to `transforms.RandomCrop` (final image size).
    """
    def __init__(self, org_size, tar_size) :
        self.org_size = org_size
        self.tar_size = tar_size
        steps = [
            transforms.Resize(org_size),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomCrop(tar_size),
        ]
        self.transform = transforms.Compose(steps)

    def __call__(self, img_tensor) :
        """Run the composed augmentation on `img_tensor` and return the result."""
        augmented = self.transform(img_tensor)
        return augmented

class ValTransforms :
    """Validation-time preprocessing: resize to the target size, then normalise.

    Args:
        tar_size: size passed to `transforms.Resize`.
    """
    def __init__(self, tar_size) :
        self.tar_size = tar_size
        resize = transforms.Resize(tar_size)
        # per-channel (x - 0.5) / 0.2
        normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.2, 0.2, 0.2))
        self.transform = transforms.Compose([resize, normalize])

    def __call__(self, img_tensor) :
        """Run the composed preprocessing on `img_tensor` and return the result."""
        processed = self.transform(img_tensor)
        return processed

In [14]:
class CutMix :
    """Batch-level CutMix: paste a random box from a shuffled copy of the batch.

    The label mixing weight is derived from the area of the box that was
    actually pasted. The box is clipped at the image border, so using the
    sampled ratio directly would give label weights that do not match the
    pixels.

    Args:
        img_height: height used to sample the box position and size.
        img_width: width used to sample the box position and size.
    """
    def __init__(self, img_height, img_width):
        self.h = img_height
        self.w = img_width
        # Beta(1, 1) is uniform on [0, 1]
        self.gen = torch.distributions.beta.Beta(1,1)

    def __call__(self, a_image, a_label):
        """Mix a batch with a permutation of itself.

        Args:
            a_image: image batch shaped (batch, channel, height, width).
            a_label: per-sample label vectors shaped (batch, classes).

        Returns:
            (c_image, c_label): the mixed images (a new tensor; the input is not
            modified) and the labels mixed with the pasted-area ratio.
        """
        batch_size = a_image.shape[0]
        img_h, img_w = a_image.shape[-2], a_image.shape[-1]

        rand = torch.randperm(batch_size, device=a_image.device)
        b_image = a_image[rand]
        b_label = a_label[rand]

        # top-left corner of the box, kept inside the actual image
        top = min(int(torch.randint(self.h, (1,))), img_h)
        left = min(int(torch.randint(self.w, (1,))), img_w)

        # box side is proportional to sqrt(1 - r), i.e. the requested area ratio is (1 - r)
        r = self.gen.sample()
        side = torch.sqrt(1 - r)
        bottom = min(top + int(self.h * side), img_h)
        right = min(left + int(self.w * side), img_w)

        c_image = a_image.clone()
        c_image[: , : , top:bottom , left:right] = b_image[: , : , top:bottom , left:right]

        # weight of the original sample = share of pixels that were NOT replaced
        pasted_area = (bottom - top) * (right - left)
        lam = 1.0 - pasted_area / float(img_h * img_w)

        c_label = a_label * lam + b_label * (1 - lam)
        return c_image, c_label

In [15]:
class_size = 1000  # number of classes: one-hot label width and model output size
batch_size = 128   # samples per batch for both data loaders
org_size = 256     # resize target applied before the random crop in training
img_size = 224     # final image size fed to the model (crop / validation resize target)

In [16]:
# Training split: reshuffled each epoch, loaded with half of the visible CPU cores.
train_dset = ImageDataset(data=train_image, label=train_label, class_size=class_size)
train_loader = DataLoader(
    train_dset,
    shuffle=True,
    batch_size=batch_size,
    num_workers=multiprocessing.cpu_count() // 2,
)

In [17]:
# Validation split. Shuffling is turned off: it gives evaluation nothing and
# only makes the batch composition (and per-batch metrics) vary between runs.
val_dset = ImageDataset(val_image, val_label, class_size)
val_loader = DataLoader(val_dset, 
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=multiprocessing.cpu_count()//2)

In [18]:
# Training augmentation (resize -> random flip -> random crop).
train_transform = TrainTransforms(org_size, img_size)
# Normalisation is a separate step for training; the training loop applies it after CutMix.
norm_transform = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.2, 0.2, 0.2))
# Validation preprocessing (resize + the same normalisation constants).
val_transform = ValTransforms(img_size)

# CutMix sized to the cropped training images.
img_cutmix = CutMix(img_size, img_size)

## Device & Seed

In [19]:
def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and switch cuDNN to deterministic mode.

    Args:
        seed: integer seed applied to every generator.
    """
    # Python / NumPy generators
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch: CPU generator plus the current and all CUDA devices
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN: deterministic kernels, no autotuner
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Fix every RNG seed before the model is built.
seed_everything(777)
# Use the GPU when one is available, otherwise fall back to the CPU.
use_cuda =  torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu") 

## Model

In [20]:
class PatchFlatten(nn.Module) :
    """Cut an image batch into non-overlapping square patches and flatten each one.

    Args:
        ch_size: number of image channels.
        img_size: image height/width (square images); must be a multiple of p_size.
        p_size: patch height/width.
    """
    def __init__(self, ch_size, img_size, p_size) :
        super(PatchFlatten , self).__init__()
        assert img_size % p_size == 0
        self.ch_size = ch_size
        self.img_size = img_size
        self.p_size = p_size

    def forward(self, img_tensor) :
        """Map (batch, ch, img, img) to (batch, patches, p_size * p_size * ch).

        Patches are ordered row-major over the patch grid; inside a patch the
        values are laid out as (row, column, channel).
        """
        grid = int(self.img_size / self.p_size)      # patches per side
        patch_dim = self.ch_size * self.p_size ** 2  # flattened patch length

        channels_last = img_tensor.permute(0, 2, 3, 1)
        # split both spatial axes into (patch index, offset inside the patch)
        blocks = channels_last.reshape(-1, grid, self.p_size, grid, self.p_size, self.ch_size)
        # bring the two patch-index axes next to each other
        blocks = blocks.permute(0, 1, 3, 2, 4, 5)
        return blocks.reshape(-1, grid * grid, patch_dim)

class PositionEmbedding(nn.Module) :
    """Project flattened patches, prepend a learnable class token, add learnable positions.

    Args:
        p_len: number of patches per image.
        v_size: length of one flattened patch.
        em_dim: embedding dimension.
        cuda_flag: when true, the token and position tensors are created on the GPU.
    """
    def __init__(self, p_len, v_size, em_dim, cuda_flag) :
        super(PositionEmbedding , self).__init__()
        self.p_len = p_len
        self.v_size = v_size
        self.em_dim = em_dim
        # class ("start") token, shape (1, em_dim)
        self.cls_tensor = self._random_parameter((1, em_dim), cuda_flag)
        # trainable positional encoding, shape (1, p_len + 1, em_dim)
        self.pos_tensor = self._random_parameter((1, p_len + 1, em_dim), cuda_flag)
        self.pos_linear = nn.Linear(v_size , em_dim)

    @staticmethod
    def _random_parameter(shape, cuda_flag) :
        """Create a trainable float32 parameter filled with NumPy standard-normal samples."""
        values = torch.FloatTensor(np.random.randn(*shape))
        if cuda_flag :
            values = values.cuda()
        return nn.Parameter(values , requires_grad=True)

    def forward(self, f_tensor) :
        """Map (batch, p_len, v_size) patches to (batch, p_len + 1, em_dim) embeddings."""
        batch_size = f_tensor.shape[0]
        # one copy of the class token per sample: (batch, 1, em_dim)
        cls_tokens = self.cls_tensor.unsqueeze(0).expand(batch_size , -1 , -1)
        # patch vectors -> embedding dimension: (batch, p_len, em_dim)
        patch_embed = self.pos_linear(f_tensor)
        tokens = torch.cat([cls_tokens , patch_embed] , dim=1)
        # positions broadcast over the batch axis
        return tokens + self.pos_tensor

class MultiHeadAttention(nn.Module) :
    """Multi-head scaled dot-product attention.

    Args:
        sen_size: nominal sequence length (stored for reference; the actual length
            is read from the input tensors, so other lengths work too).
        d_model: model / embedding dimension.
        num_heads: number of attention heads; must divide d_model.

    Raises:
        ValueError: if d_model is not divisible by num_heads.
    """
    def __init__(self, sen_size,  d_model, num_heads) :
        super(MultiHeadAttention , self).__init__()
        if d_model % num_heads != 0 :
            raise ValueError('d_model ({}) must be divisible by num_heads ({})'.format(d_model, num_heads))
        self.sen_size = sen_size
        self.d_model = d_model
        self.num_heads = num_heads
        self.depth = d_model // num_heads # per-head dimension

        self.q_layer = nn.Linear(d_model , d_model)
        self.k_layer = nn.Linear(d_model , d_model)
        self.v_layer = nn.Linear(d_model , d_model)
        self.o_layer = nn.Linear(d_model , d_model)

        # plain Python float: no device to keep in sync with the module
        self.scale = float(self.depth) ** 0.5

    def split(self , tensor) :
        """(batch, seq, d_model) -> (batch, num_heads, seq, depth)."""
        tensor = torch.reshape(tensor , (tensor.shape[0] , tensor.shape[1] , self.num_heads , self.depth))
        return torch.transpose(tensor , 1 , 2)

    def merge(self , tensor) :
        """(batch, num_heads, seq, depth) -> (batch, seq, d_model)."""
        tensor = torch.transpose(tensor , 1 , 2)
        return torch.reshape(tensor , (tensor.shape[0] , tensor.shape[1] , self.d_model))

    def scaled_dot_production(self, q_tensor, k_tensor, v_tensor, m_tensor) :
        """Per-head attention softmax(Q K^T / sqrt(depth) - mask * 1e6) V.

        Args:
            q_tensor, k_tensor, v_tensor: projected inputs, (batch, seq, d_model).
            m_tensor: optional mask broadcastable to (batch, num_heads, q_len, k_len);
                positions set to 1 are suppressed. Pass None for no masking.

        Returns:
            Attention output shaped (batch, num_heads, q_len, depth).
        """
        q_tensor = self.split(q_tensor)
        k_tensor = self.split(k_tensor)
        v_tensor = self.split(v_tensor)
        k_tensor_T = torch.transpose(k_tensor , 2 , 3) # (batch, num_heads, depth, k_len)

        qk_tensor = torch.matmul(q_tensor , k_tensor_T) / self.scale # (batch, num_heads, q_len, k_len)
        # identity check: comparing a tensor with `!=` is not a reliable None test
        if m_tensor is not None :
            qk_tensor = qk_tensor - (m_tensor * 1e+6)

        qk_tensor = F.softmax(qk_tensor , dim = -1)
        return torch.matmul(qk_tensor , v_tensor)

    def forward(self, q_in, k_in, v_in, m_in) :
        """Project the inputs, attend, merge the heads and apply the output projection."""
        q_tensor = self.q_layer(q_in)
        k_tensor = self.k_layer(k_in)
        v_tensor = self.v_layer(v_in)

        att_tensor = self.scaled_dot_production(q_tensor, k_tensor, v_tensor, m_in)
        att_tensor = self.merge(att_tensor)

        return self.o_layer(att_tensor)

class FeedForward(nn.Module) :
    """Position-wise feed-forward network: Linear -> ReLU -> Linear.

    Args:
        hidden_size: width of the inner layer.
        d_model: input and output dimension.
    """
    def __init__(self, hidden_size, d_model) :
        super(FeedForward , self).__init__()
        self.hidden_size = hidden_size
        self.d_model = d_model
        layers = [
            nn.Linear(d_model , hidden_size),  # expand
            nn.ReLU(),
            nn.Linear(hidden_size , d_model),  # project back
        ]
        self.ff = nn.Sequential(*layers)

    def forward(self, in_tensor) :
        """Apply the network along the last dimension of `in_tensor`."""
        return self.ff(in_tensor)


class EncoderBlock(nn.Module) :
    """One encoder layer: self-attention and feed-forward sub-layers.

    Each sub-layer is followed by dropout, a residual connection and a
    LayerNorm applied to the sum.

    Args:
        sen_size: sequence length handed to the attention layer.
        d_model: model dimension.
        num_heads: number of attention heads.
        hidden_size: inner width of the feed-forward network.
        drop_rate: dropout probability for both sub-layers.
        norm_rate: epsilon used by both LayerNorm layers.
    """
    def __init__(self, sen_size, d_model, num_heads, hidden_size, drop_rate, norm_rate) :
        super(EncoderBlock , self).__init__()
        self.sen_size = sen_size
        self.d_model = d_model
        self.num_heads = num_heads
        self.hidden_size = hidden_size

        # sub-layers
        self.mha_layer = MultiHeadAttention(sen_size , d_model , num_heads)
        self.ff_layer = FeedForward(hidden_size , d_model)

        # dropout + normalisation after the attention sub-layer
        self.drop1_layer = nn.Dropout(drop_rate)
        self.norm1_layer = nn.LayerNorm(d_model , eps=norm_rate)
        # dropout + normalisation after the feed-forward sub-layer
        self.drop2_layer = nn.Dropout(drop_rate)
        self.norm2_layer = nn.LayerNorm(d_model , eps=norm_rate)

    def forward(self, in_tensor) :
        """Return the encoded tensor, same shape as `in_tensor`."""
        # unmasked self-attention: query, key and value are all the input
        attended = self.drop1_layer(self.mha_layer(in_tensor, in_tensor, in_tensor, None))
        h_tensor = self.norm1_layer(in_tensor + attended)

        transformed = self.drop2_layer(self.ff_layer(h_tensor))
        return self.norm2_layer(h_tensor + transformed)

class Encoder(nn.Module) :
    """Stack of `layer_size` EncoderBlock layers applied one after another.

    The blocks are registered as 'Encoder_Block0', 'Encoder_Block1', ... inside
    `self.en_net`; the remaining arguments are passed to every block.
    """
    def __init__(self, layer_size, sen_size, d_model, num_heads, hidden_size, drop_rate, norm_rate) :
        super(Encoder , self).__init__()
        self.layer_size = layer_size

        blocks = [
            EncoderBlock(sen_size, d_model, num_heads, hidden_size, drop_rate, norm_rate)
            for _ in range(layer_size)
        ]
        self.en_net = nn.Sequential()
        for position, block in enumerate(blocks) :
            self.en_net.add_module('Encoder_Block' + str(position) , block)

    def forward(self, in_tensor) :
        """Feed `in_tensor` through every block in order."""
        return self.en_net(in_tensor)


In [42]:
class VIT(nn.Module) :
    """Vision Transformer classifier.

    Pipeline: flatten patches -> embed (class token + positions) -> encoder
    stack -> linear head applied to the class-token output.

    Args:
        layer_size: number of encoder blocks.
        class_size: number of output classes.
        channel_size: number of image channels.
        img_size: height/width of the (square) input images.
        patch_size: height/width of one patch; must divide img_size.
        em_dim: embedding / model dimension.
        num_heads: attention heads per block.
        hidden_size: inner width of the feed-forward networks.
        drop_rate: dropout probability.
        norm_rate: LayerNorm epsilon.
        cuda_flag: forwarded to PositionEmbedding for creating its tensors on the
            GPU; the forward pass itself no longer depends on it.
    """
    def __init__(self , 
        layer_size , 
        class_size , 
        channel_size , 
        img_size , 
        patch_size , 
        em_dim , 
        num_heads , 
        hidden_size , 
        drop_rate , 
        norm_rate , 
        cuda_flag) :
        super(VIT , self).__init__()

        self.class_size = class_size
        self.img_size = img_size
        self.patch_size = patch_size
        self.layer_size = layer_size
        self.em_dim = em_dim
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.drop_rate = drop_rate 
        self.norm_rate = norm_rate
        self.cuda_flag = cuda_flag

        p_len = (img_size // patch_size) ** 2       # patches per image
        v_size = (patch_size ** 2) * channel_size   # flattened patch length
        self.p_flatten = PatchFlatten(channel_size, img_size, patch_size)
        self.p_embedding = PositionEmbedding(p_len, v_size, em_dim, cuda_flag)
        # the sequence is one element longer than p_len because of the class token
        self.encoder = Encoder(layer_size, p_len+1, em_dim, num_heads, hidden_size, drop_rate, norm_rate)

        self.o_layer = nn.Linear(em_dim, class_size)

        self.init_param()

    # Xavier Initialization
    def init_param(self) :
        """Xavier-uniform init for every parameter with more than one dimension.

        This also re-initialises the class token and the positional tensor of
        the embedding layer; biases and LayerNorm weights (1-D) are untouched.
        """
        for p in self.parameters() :
            if p.dim() > 1 :
                nn.init.xavier_uniform_(p)

    def forward(self, tensor) :
        """Map an image batch (batch, channel, img, img) to logits (batch, class_size)."""
        f_tensor = self.p_flatten(tensor) # patch flatten
        em_tensor = self.p_embedding(f_tensor) # positional embedding
        o_tensor = self.encoder(em_tensor)

        # The class token sits at sequence position 0. Plain indexing follows the
        # tensor's own device, so no manual .cuda() handling is needed.
        cls_state = o_tensor[:, 0]
        return self.o_layer(cls_state)

## Acc & Loss function

In [43]:
def acc_fn(y_output , y_label) :
    """Fraction of samples whose arg-max prediction equals the arg-max of the label vector.

    Args:
        y_output: model scores, class axis last.
        y_label: label vectors with the same shape (one-hot or mixed).

    Returns:
        Scalar float tensor in [0, 1].
    """
    predicted = y_output.argmax(dim=-1)
    target = y_label.argmax(dim=-1)
    return (predicted == target).float().mean()

def loss_fn(y_output, y_label) :
    """Cross-entropy between logits and (possibly soft) label vectors, averaged over the batch.

    Args:
        y_output: logits shaped (batch, classes).
        y_label: label weights shaped (batch, classes).

    Returns:
        Scalar tensor: mean over the batch of -sum(label * log_softmax(logits)).
    """
    neg_log_prob = -F.log_softmax(y_output, -1)
    per_sample = torch.sum(torch.mul(neg_log_prob, y_label), dim=1)
    return per_sample.mean()

## Hyperparameter

In [44]:
# ViT-Base sized encoder. Note: patch_size is 32, so this is a Base/32-style setup, not Base/16.
layer_size = 12       # number of encoder blocks
channel_size = 3      # image channels
patch_size = 32       # patch height/width in pixels
embedding_dim = 768   # model dimension; also read by schedule_fn
hidden_size = 3072    # feed-forward inner width
num_heads = 12        # attention heads per block
drop_rate = 1e-1      # dropout probability
norm_rate = 1e-6      # LayerNorm epsilon

In [45]:
# Build the model with explicit keyword arguments and move it to the selected device.
model = VIT(
    layer_size=layer_size,
    class_size=class_size,
    channel_size=channel_size,
    img_size=img_size,
    patch_size=patch_size,
    em_dim=embedding_dim,
    num_heads=num_heads,
    hidden_size=hidden_size,
    drop_rate=drop_rate,
    norm_rate=norm_rate,
    cuda_flag=use_cuda,
).to(device)

## Logging

In [32]:
# Root folder on Drive for this experiment: TensorBoard logs go to 'Log', checkpoints to 'Model'.
dir_path = '/content/drive/MyDrive/Model/CV/ImageClassification/VIT'

# TensorBoard writer used by the training loop for the train/* and val/* scalars.
writer = SummaryWriter(os.path.join(dir_path, 'Log'))

## Optimizer & Scheduler

In [46]:
# Number of scheduler steps over which the learning rate ramps up linearly.
warmup_steps = 2000

def schedule_fn(epoch , lr) :
    """Warm-up / inverse-square-root learning-rate factor for LambdaLR.

    The target rate is embedding_dim^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
    with step = epoch + 1. The result is divided by `lr` because LambdaLR
    multiplies the optimizer's base rate by the returned factor, so the
    effective rate equals the target when `lr` is that base rate.

    Args:
        epoch: zero-based counter supplied by the scheduler.
        lr: base learning rate of the optimizer.

    Returns:
        Multiplicative factor for the base learning rate.
    """
    step_num = epoch + 1
    decay_term = step_num ** (-0.5)
    warmup_term = step_num * warmup_steps ** (-1.5)
    target_lr = embedding_dim ** (-0.5) * min(decay_term, warmup_term)
    return target_lr / lr

In [47]:
# Placeholder base rate: schedule_fn divides by it, so the effective rate is set by the schedule alone.
dumb_lr = 1e-4

optimizer = optim.Adam(
    model.parameters(),
    lr=dumb_lr,
    betas=(0.9, 0.999),
    weight_decay=0.1,
)
scheduler = optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step: schedule_fn(step, dumb_lr),
)

In [48]:
def progressLearning(value, endvalue, loss, acc, bar_length=50):
    """Redraw a one-line text progress bar with the current loss and accuracy.

    Args:
        value: zero-based index of the step that just finished.
        endvalue: total number of steps.
        loss: loss value to display.
        acc: accuracy value to display.
        bar_length: width of the bar in characters.
    """
    done = value + 1
    percent = float(done) / endvalue
    # a negative repeat count yields an empty string, so the bar starts as just '>'
    arrow = '-' * int(round(percent * bar_length)-1) + '>'
    bar = arrow + ' ' * (bar_length - len(arrow))
    # the leading '\r' moves back to the line start so the bar overwrites itself
    line = "\r[{0}] {1}/{2} \t Loss : {3:.3f} , Acc : {4:.3f}".format(bar, done, endvalue, loss, acc)
    sys.stdout.write(line)
    sys.stdout.flush()

In [49]:
min_loss = np.inf  # lowest validation loss seen so far
epochs = 100       # maximum number of epochs
stop_count = 0     # consecutive epochs without a validation-loss improvement
log_count = 0      # x-axis counter for the train/* TensorBoard scalars

In [None]:
# Checkpoints go to <dir_path>/Model; create the folder up front so that
# torch.save cannot fail on a missing directory after a full epoch of work.
ckpt_dir = os.path.join(dir_path, 'Model')
os.makedirs(ckpt_dir, exist_ok=True)

for epoch in range(epochs) :
    model.train()
    print('Epoch : %d/%d \t Learning Rate : %e' %(epoch, epochs, optimizer.param_groups[0]["lr"]))
    # training process
    for idx, (img_data, img_label) in enumerate(train_loader) :
        # transfer the raw batch first, then convert and scale on the device
        img_data = img_data.to(device).float() / 255
        img_label = img_label.to(device)

        optimizer.zero_grad()

        # augmentation -> CutMix -> normalisation, in that order
        img_data = train_transform(img_data)
        img_data, img_label = img_cutmix(img_data, img_label)
        img_data = norm_transform(img_data)
        img_out = model(img_data)

        loss = loss_fn(img_out, img_label)
        acc = acc_fn(img_out, img_label)

        loss.backward()
        optimizer.step()
        # schedule_fn is a step-based warm-up/decay rule (warmup_steps = 2000).
        # Advancing it once per epoch would keep the rate inside the warm-up ramp
        # for the whole run, so it is advanced after every optimizer step.
        scheduler.step()

        progressLearning(idx, len(train_loader), loss.item(), acc.item())

        # log every 100th batch
        if (idx + 1) % 100 == 0 :
            writer.add_scalar('train/loss', loss.item(), log_count)
            writer.add_scalar('train/acc', acc.item(), log_count)
            log_count += 1

    # validation process
    model.eval()
    loss_sum = 0.0
    acc_sum = 0.0
    sample_count = 0
    with torch.no_grad() :
        for img_data, img_label in val_loader :
            img_data = img_data.to(device).float() / 255
            img_label = img_label.to(device)

            img_data = val_transform(img_data)
            img_out = model(img_data)

            # weight each batch by its size so a smaller last batch does not skew the mean
            n_batch = img_data.shape[0]
            loss_sum += loss_fn(img_out, img_label).item() * n_batch
            acc_sum += acc_fn(img_out, img_label).item() * n_batch
            sample_count += n_batch

    loss_eval = loss_sum / sample_count
    acc_eval = acc_sum / sample_count

    writer.add_scalar('val/loss', loss_eval, epoch)
    writer.add_scalar('val/acc', acc_eval, epoch)
    # printed before the early-stopping check so the last epoch is reported too
    print('\nTest Loss : %.3f \t Test Accuracy : %.3f\n' %(loss_eval, acc_eval))

    if loss_eval < min_loss :
        # new best model: save a checkpoint and reset the patience counter
        min_loss = loss_eval
        torch.save({'epoch' : (epoch) ,  
            'model_state_dict' : model.state_dict() , 
            'loss' : loss_eval , 
            'acc' : acc_eval} , 
            os.path.join(ckpt_dir, 'vit_model.pt'))
        stop_count = 0 
    else :
        stop_count += 1
        # stop after 5 consecutive epochs without improvement
        if stop_count >= 5 :      
            print('\tTraining Early Stopped')
            break

Epoch : 0/100 	 Learning Rate : 4.034358e-07
[->                                                ] 134/3003 	 Loss : 7.298 , Acc : 0.000