In [1]:
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.utils import data
import torchvision
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
# Prefer the GPU when CUDA is usable; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
class FeedForwardNetwork(nn.Module):
    """Two-layer MLP that maps emb_dim -> hidden_dim -> emb_dim.

    A ReLU follows the first projection; dropout is applied after the
    activation and again after the second projection.

    Args:
        emb_dim: size of the input and output features.
        hidden_dim: size of the intermediate features.
        dropout: dropout probability used at both dropout sites.
    """

    def __init__(self, emb_dim, hidden_dim, dropout=0.):
        super().__init__()
        expand = nn.Linear(emb_dim, hidden_dim)
        activation = nn.ReLU()
        hidden_drop = nn.Dropout(dropout)
        project = nn.Linear(hidden_dim, emb_dim)
        output_drop = nn.Dropout(dropout)
        # Positions inside the Sequential determine the state_dict keys
        # (ffn.0.*, ffn.3.*), so the layer order must stay as is.
        self.ffn = nn.Sequential(expand, activation, hidden_drop, project, output_drop)

    def forward(self, x):
        """Apply the MLP over the last dimension of ``x``."""
        transformed = self.ffn(x)
        return transformed

In [3]:
class MultiHeadDotProductSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with an output projection.

    Args:
        emb_dim: embedding size of the input tokens; must be divisible by
            ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied after the output projection.

    Raises:
        ValueError: if ``emb_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, emb_dim, num_heads, dropout=0.):
        super().__init__()
        if emb_dim % num_heads != 0:
            raise ValueError('emb_dim must be divisible by num_heads')
        self.num_heads = num_heads
        # Scores are dot products of per-head vectors, so they are scaled by
        # 1/sqrt(head_dim) rather than 1/sqrt(emb_dim).
        self.scale = (emb_dim // num_heads) ** -0.5
        self.qkv = nn.Linear(emb_dim, 3 * emb_dim)
        self.out = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask=None):
        """Run self-attention over ``x``.

        Args:
            x: tensor of shape [batch_size, sequence_length, embedding_size].
            mask: optional boolean tensor holding one flag per token that
                follows the leading [class] token (True = valid token). After
                flattening it must have shape
                [batch_size, sequence_length - 1].

        Returns:
            Tensor of shape [batch_size, sequence_length, embedding_size].
            Positions flagged as padding never act as keys; their own output
            rows are finite but carry no meaning.
        """
        b, n, _ = x.shape
        h = self.num_heads
        # shape of q, k, v: [batch_size, sequence_length, embedding_size]
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # shape of q, k, v: [batch_size, num_heads, sequence_length, embedding_size / num_heads]
        q, k, v = (t.reshape(b, n, h, -1).transpose(1, 2) for t in (q, k, v))
        # shape of attention_score: [batch_size, num_heads, sequence_length, sequence_length]
        attention_score = torch.matmul(q, k.transpose(2, 3)) * self.scale

        if mask is not None:
            mask = mask.flatten(1).to(torch.bool)
            # The leading [class] token must always be valid, so a column of
            # True is prepended to the mask.
            cls_valid = torch.ones((mask.shape[0], 1), dtype=torch.bool, device=mask.device)
            mask = torch.cat((cls_valid, mask), dim=1)
            assert mask.shape[-1] == attention_score.shape[-1], 'mask has incorrect dimensions'
            # Mask key columns only, with explicit singleton head and query
            # axes so the mask broadcasts to [b, h, n, n] for any batch size.
            # Masking whole query rows as well would leave rows made only of
            # -inf, which softmax turns into NaN.
            key_mask = mask[:, None, None, :]
            attention_score = attention_score.masked_fill(~key_mask, float('-inf'))

        attention_weight = F.softmax(attention_score, dim=-1)
        # shape of out: [batch_size, num_heads, sequence_length, embedding_size / num_heads]
        out = torch.matmul(attention_weight, v)
        # Merge the heads back: [batch_size, sequence_length, embedding_size]
        out = out.transpose(2, 1).reshape(b, n, -1)
        out = self.out(out)
        return out
    

In [4]:
# Smoke test: push a dummy batch through the attention layer together with a
# padding mask and display the shape of the result.
msa = MultiHeadDotProductSelfAttention(emb_dim=10, num_heads=2, dropout=0.)
a = torch.zeros((2, 5, 10))  # [batch_size, sequence_length, embedding_size]
# One flag per token after the leading [class] token (True = valid).
mask = torch.tensor([
    [True, True, False, False],
    [True, False, False, False],
])
out = msa(a, mask)
out.shape

torch.Size([2, 5, 10])

In [5]:
class EncoderBlock(nn.Module):
    """Pre-norm Transformer encoder block.

    Computes ``y = attention(LayerNorm(x)) + x`` and then
    ``ffn(LayerNorm(y)) + y``.

    Args:
        emb_dim: embedding size of the tokens.
        num_heads: number of attention heads.
        hidden_dim: hidden size of the feed-forward network.
        dropout: dropout probability forwarded to both sub-layers.
    """

    def __init__(self, emb_dim, num_heads, hidden_dim, dropout=0.):
        super().__init__()
        self.num_heads = num_heads
        self.layer_norm_input = nn.LayerNorm(emb_dim)
        self.layer_norm_out = nn.LayerNorm(emb_dim)
        self.attention = MultiHeadDotProductSelfAttention(
            emb_dim, num_heads, dropout
        )
        self.ffn = FeedForwardNetwork(emb_dim, hidden_dim, dropout)

    def forward(self, x):
        """Apply the attention and feed-forward sub-layers, each with a residual connection."""
        # Attention sub-layer with residual connection.
        attended = self.attention(self.layer_norm_input(x)) + x
        # Feed-forward sub-layer with residual connection.
        return self.ffn(self.layer_norm_out(attended)) + attended

In [6]:
class Encoder(nn.Module):
    """Stack of ``N`` identically configured ``EncoderBlock`` layers applied in sequence.

    Args:
        N: number of encoder blocks.
        emb_dim: embedding size of the tokens.
        num_heads: number of attention heads per block.
        hidden_dim: hidden size of each block's feed-forward network.
        dropout: dropout probability forwarded to every block.
    """

    def __init__(self, N, emb_dim, num_heads, hidden_dim, dropout=0.):
        super().__init__()
        blocks = []
        for _ in range(N):
            blocks.append(EncoderBlock(emb_dim, num_heads, hidden_dim, dropout))
        self.layers = nn.Sequential(*blocks)

    def forward(self, x):
        """Pass ``x`` through every block in order."""
        encoded = self.layers(x)
        return encoded

In [7]:
class ViT(nn.Module):
    """Vision Transformer classifier.

    Images are cut into square patches, each patch is linearly embedded, a
    learnable [class] token is prepended, learnable positional embeddings are
    added, and the encoder output at the [class] position is fed to an MLP
    head that produces ``num_classes`` logits.

    Args:
        cfg: configuration object providing ``N``, ``patch_size``, ``emb_dim``,
            ``hidden_dim``, ``num_heads``, ``dropout``, ``image_size``,
            ``num_channels`` and ``num_classes``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        # Mirror the model hyperparameters onto the module.
        for field in ('N', 'patch_size', 'emb_dim', 'hidden_dim',
                      'num_heads', 'dropout', 'image_size', 'num_channels'):
            setattr(self, field, getattr(cfg, field))
        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.num_classes = cfg.num_classes
        assert self.image_size % self.patch_size == 0, 'image dimensions must be divisible by the patch size'

        patch_dim = self.num_channels * self.patch_size ** 2
        head_hidden = 2 * self.emb_dim
        self.class_token = nn.Parameter(torch.randn(1, 1, self.emb_dim))
        self.embedding = nn.Linear(patch_dim, self.emb_dim)
        # One positional embedding per patch plus one for the [class] token.
        self.PE = nn.Parameter(torch.randn(1, self.num_patches + 1, self.emb_dim))
        self.emb_dropout = nn.Dropout(self.dropout)
        self.transformer = Encoder(self.N, self.emb_dim, self.num_heads, self.hidden_dim, self.dropout)
        self.MLP_head = nn.Sequential(
            nn.Linear(self.emb_dim, head_hidden),
            nn.ReLU(),
            nn.Linear(head_hidden, self.num_classes),
        )

    def forward(self, x, tokenize=True):
        """Compute class logits.

        Args:
            x: image batch [batch_size, num_channels, image_size, image_size]
                when ``tokenize`` is True; otherwise already flattened patches
                [batch_size, num_patches, num_channels * patch_size ** 2].
            tokenize: whether to split ``x`` into patches first.

        Returns:
            Logits of shape [batch_size, num_classes].
        """
        tokens = self.tokenize(x) if tokenize else x
        batch_size = tokens.shape[0]
        tokens = self.embedding(tokens)
        cls_tokens = self.class_token.repeat(batch_size, 1, 1)
        tokens = torch.concat((cls_tokens, tokens), dim=1)
        tokens += self.PE
        tokens = self.emb_dropout(tokens)
        encoded = self.transformer(tokens)
        # Only the [class] position feeds the classification head.
        return self.MLP_head(encoded[:, 0, :])

    def tokenize(self, x):
        """Split images into flattened patches.

        Patches are ordered row by row (left to right within a row of
        patches); each patch is flattened in (channel, row, column) order.

        Returns:
            Tensor of shape [batch_size, num_patches, num_channels * patch_size ** 2].
        """
        b, c, h, w = x.shape
        assert h == self.image_size and w == self.image_size, 'the size of the input image is incorrect'
        side = self.patch_size
        per_side = self.image_size // side
        # [b, c, rows, side, cols, side] -> [b, rows, cols, c, side, side]
        grid = x.reshape(b, c, per_side, side, per_side, side).permute(0, 2, 4, 1, 3, 5)
        return grid.reshape(b, self.num_patches, -1)

    def print_num_params(self):
        """Print the number of trainable parameters."""
        trainable_sizes = [p.numel() for p in self.parameters() if p.requires_grad]
        print(sum(trainable_sizes))


In [8]:
import ssl
# NOTE(review): this replaces the factory for the default HTTPS context with the
# unverified one, so HTTPS requests that rely on the default context skip
# certificate verification for the rest of the session. Presumably a workaround
# for certificate errors during the dataset download in the following cells —
# confirm it is still needed, and do not run this on an untrusted network.
ssl._create_default_https_context = ssl._create_unverified_context

In [9]:
# Load the CIFAR-10 training split and compute per-channel mean / std of the
# pixel values after dividing them by 255.
cifar_train = torchvision.datasets.CIFAR10(root="../data", train=True, download=True)
print(dir(cifar_train))
print(cifar_train.data.shape) # (50000, 32, 32, 3)
cifardata = cifar_train.data / 255
# Reduce over every axis except the trailing channel axis.
reduce_axes = (0, 1, 2)
mean, std = (torch.tensor(stat(axis=reduce_axes)) for stat in (cifardata.mean, cifardata.std))
print(mean, std)

Files already downloaded and verified
['__add__', '__class__', '__class_getitem__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__orig_bases__', '__parameters__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '__weakref__', '_check_integrity', '_format_transform_repr', '_is_protocol', '_load_meta', '_repr_indent', 'base_folder', 'class_to_idx', 'classes', 'data', 'download', 'extra_repr', 'filename', 'meta', 'root', 'target_transform', 'targets', 'test_list', 'tgz_md5', 'train', 'train_list', 'transform', 'transforms', 'url']
(50000, 32, 32, 3)
tensor([0.4914, 0.4822, 0.4465], dtype=torch.float64) tensor([0.2470, 0.2435, 0.2616], dtype=torch.float64)


In [10]:
# Split the training set into 45,000 training and 5,000 validation samples;
# the explicitly seeded generator keeps the split reproducible.
full_train = torchvision.datasets.CIFAR10(root="../data", train=True, download=True)
split_generator = torch.Generator().manual_seed(42)
train_and_valid = data.random_split(full_train, [45000, 5000], generator=split_generator)

Files already downloaded and verified


In [11]:
class TrainDataset(data.Dataset):
    """Wrap a dataset of (image, label) samples with training-time augmentation.

    Each image is converted to a tensor, randomly cropped to 32x32 after
    padding by 4, randomly flipped horizontally, and normalized with the
    per-channel mean / std hard-coded below.

    Args:
        dataset: indexable dataset whose items have the image at position 0
            and the label at position 1.
    """

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        self.trans = transforms.Compose([transforms.ToTensor(),
                                         transforms.RandomCrop(32, padding=4),
                                         transforms.RandomHorizontalFlip(p=0.5),
                                         transforms.ConvertImageDtype(torch.float),
                                         transforms.Normalize([0.4914, 0.4822, 0.4465],
                                                              [0.2470, 0.2435, 0.2616],
                                                              inplace=True)])

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for the sample at ``index``."""
        # Fetch the sample once: indexing the wrapped dataset separately for
        # the image and the label repeats the whole lookup for every item.
        sample = self.dataset[index]
        return (self.trans(sample[0]), sample[1])

In [12]:
class TestDataset(data.Dataset):
    """Wrap a dataset of (image, label) samples with evaluation-time preprocessing.

    Each image is converted to a tensor and normalized with the per-channel
    mean / std hard-coded below; no random augmentation is applied.

    Args:
        dataset: indexable dataset whose items have the image at position 0
            and the label at position 1.
    """

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        self.trans = transforms.Compose([transforms.ToTensor(),
                                         transforms.ConvertImageDtype(torch.float),
                                         transforms.Normalize([0.4914, 0.4822, 0.4465],
                                                              [0.2470, 0.2435, 0.2616],
                                                              inplace=True)])

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for the sample at ``index``."""
        # Fetch the sample once: indexing the wrapped dataset separately for
        # the image and the label repeats the whole lookup for every item.
        sample = self.dataset[index]
        return (self.trans(sample[0]), sample[1])

In [13]:
# Wrap the splits: random augmentation for the training subset only, plain
# tensor conversion + normalization for the validation and test data.
train_subset, valid_subset = train_and_valid
cifar_test = torchvision.datasets.CIFAR10(root="../data", train=False, download=True)
train_dataset = TrainDataset(train_subset)
valid_dataset = TestDataset(valid_subset)
test_dataset = TestDataset(cifar_test)

Files already downloaded and verified


In [14]:
def loss_acc(output, target, criterion):
    """Compute the loss and the top-1 accuracy of a batch.

    Args:
        output: scores with the class axis at dim 1.
        target: class indices compared element-wise with the argmax of ``output``.
        criterion: callable evaluated as ``criterion(output, target)``.

    Returns:
        ``(loss, acc)`` where ``loss`` is whatever ``criterion`` returns and
        ``acc`` is a Python float: the fraction of matching predictions.
    """
    num_correct = output.argmax(dim=1).eq(target).sum()
    acc = (num_correct / target.numel()).item()
    return criterion(output, target), acc

In [15]:
def lr_ratio(emb_dim, warmup_steps, cur_step):
    """Learning-rate multiplier with linear warm-up and inverse-sqrt decay.

    Computes ``emb_dim ** -0.5 * min(cur_step ** -0.5, cur_step * warmup_steps ** -1.5)``:
    the multiplier grows linearly while ``cur_step < warmup_steps`` and decays
    as ``cur_step ** -0.5`` afterwards.

    Args:
        emb_dim: model embedding size.
        warmup_steps: number of warm-up steps; a value <= 0 disables warm-up.
        cur_step: current step index.

    Returns:
        The multiplier; 0 for ``cur_step <= 0``.
    """
    # Non-positive steps have no meaningful rate (0 ** -0.5 is a division by
    # zero and a negative base would produce a complex number).
    if cur_step <= 0:
        return 0
    lr = emb_dim ** -0.5
    decay = cur_step ** -0.5
    # warmup_steps ** -1.5 divides by zero for 0, so treat "no warm-up" as
    # pure inverse-sqrt decay.
    if warmup_steps <= 0:
        return lr * decay
    return lr * min(decay, cur_step * warmup_steps ** -1.5)

In [16]:
def get_lr(optimizer):
    """Return the current learning rate of the optimizer's first parameter group."""
    first_group = optimizer.param_groups[0]
    return first_group['lr']

In [17]:
def train_ViT(net, train_dataset, valid_dataset):
    """Train ``net`` and validate it after every epoch.

    Hyperparameters are read from ``net.cfg`` (``batch_size``, ``num_workers``,
    ``version``, ``lr``, ``weight_decay``, ``warmup_steps``, ``num_epochs``).
    Batches are moved to the module-level ``device``; the model itself is
    expected to be on that device already.

    Side effects:
        * appends the configuration to ``configs.txt``;
        * writes TensorBoard scalars to ``runs/ViT_CIFAR_{version}``;
        * prints one summary line per epoch;
        * saves ``net.state_dict()`` to ``ViT_CIFAR_{version}`` after every
          epoch, overwriting the previous file.

    Args:
        net: model exposing ``cfg`` and ``emb_dim``.
        train_dataset: dataset yielding ``(input, target)`` pairs for training.
        valid_dataset: dataset yielding ``(input, target)`` pairs for validation.
    """
    cfg = net.cfg
    with open('configs.txt', 'a') as f:
        f.write('{\n')
        for k, v in cfg.__dict__.items():
            f.write(k + ': ' + str(v) + "\n")
        f.write('}\n')
    train_dataloader = data.DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.num_workers)
    valid_dataloader = data.DataLoader(valid_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers)
    optimizer = torch.optim.Adam(net.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    # The scheduler multiplies cfg.lr by the warm-up / decay ratio of the current step.
    warmup_lr = lambda cur_step: lr_ratio(net.emb_dim, cfg.warmup_steps, cur_step)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_lr)
    criterion = torch.nn.CrossEntropyLoss()
    writer = SummaryWriter(f"runs/ViT_CIFAR_{cfg.version}")
    global_step = 0
    try:
        for epoch in range(cfg.num_epochs):
            net.train()
            train_loss, train_acc = [], []
            for inputs, target in train_dataloader:
                global_step += 1
                inputs, target = inputs.to(device), target.to(device)
                optimizer.zero_grad()
                output = net(inputs)
                loss, acc = loss_acc(output, target, criterion)
                loss.backward()
                # Keep plain floats: storing the loss tensor itself would keep
                # a graph-attached tensor alive for every step of the epoch.
                loss_value = loss.item()
                train_loss.append(loss_value)
                train_acc.append(acc)
                # Log the rate that is about to be used by this optimizer step.
                writer.add_scalar('learning rate', get_lr(optimizer), global_step=global_step)
                optimizer.step()
                scheduler.step()
                writer.add_scalar('train/loss', loss_value, global_step=global_step)
                writer.add_scalar('train/accuracy', acc, global_step=global_step)
            with torch.no_grad():
                net.eval()
                valid_loss, valid_acc = [], []
                for inputs, target in valid_dataloader:
                    inputs, target = inputs.to(device), target.to(device)
                    output = net(inputs)
                    loss, acc = loss_acc(output, target, criterion)
                    valid_loss.append(loss.item())
                    valid_acc.append(acc)
                writer.add_scalar('valid/loss', sum(valid_loss) / len(valid_loss), global_step=global_step)
                writer.add_scalar('valid/accuracy', sum(valid_acc) / len(valid_acc), global_step=global_step)
            # Per-epoch means are averages over batches, not over samples.
            message = list(map(lambda x:sum(x) / len(x), (train_loss, train_acc, valid_loss, valid_acc)))
            print(f'epoch {epoch+1:3d}, train loss: {message[0]:8.4f}, train accuracy: {message[1]:8.4f}, valid loss: {message[2]:8.4f}, valid accuracy: {message[3]:8.4f}')
            torch.save(net.state_dict(), f'ViT_CIFAR_{cfg.version}')
    finally:
        # Flush pending events and release the event file even when training
        # is interrupted or fails.
        writer.close()

In [18]:
class Configuration:
    """Hyperparameters of one experiment, identified by ``version``.

    Attributes are assigned in a fixed order (version, model settings,
    training settings), so code that iterates over ``__dict__`` sees them in
    that order.
    """

    def __init__(self, version):
        self.version = version
        self._set_model_hyperparameters()
        self._set_train_hyperparameters()

    def _set_model_hyperparameters(self):
        """Assign the architecture-related settings."""
        self.N = 8
        self.patch_size = 4
        self.emb_dim = 128
        self.hidden_dim = 256
        self.num_heads = 4
        self.dropout = 0.05
        self.image_size = 32
        self.num_channels = 3
        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side ** 2
        self.num_classes = 10

    def _set_train_hyperparameters(self):
        """Assign the optimization-related settings."""
        self.batch_size = 256
        self.lr = 1e-1
        self.weight_decay = 5e-4
        self.warmup_steps = 1000
        self.num_epochs = 75
        self.num_workers = 0

    def __str__(self):
        """Return the version label."""
        return self.version

In [None]:
    # Build the model from the 'Version4' configuration, move it to the
    # selected device, print its trainable parameter count, then train it.
    cfg = Configuration('Version4')
    net = ViT(cfg)
    net = net.to(device)
    net.print_num_params()
    train_ViT(net, train_dataset, valid_dataset)

1110154
epoch   1, train loss:   2.2402, train accuracy:   0.1540, valid loss:   2.0744, valid accuracy:   0.2178
epoch   2, train loss:   2.0106, train accuracy:   0.2448, valid loss:   1.9771, valid accuracy:   0.2638
epoch   3, train loss:   1.9153, train accuracy:   0.2878, valid loss:   1.9100, valid accuracy:   0.2823
epoch   4, train loss:   1.8166, train accuracy:   0.3250, valid loss:   1.7713, valid accuracy:   0.3560
epoch   5, train loss:   1.7142, train accuracy:   0.3654, valid loss:   1.7098, valid accuracy:   0.3781
