### Transformer using PyTorch

In [9]:
import torch
import torchvision
import math
from argparse import ArgumentParser
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchmetrics.functional import accuracy
from einops import rearrange
from torch import nn, einsum

In [10]:


class SelfAttention(nn.Module):
    """Multi-head self-attention built from separate q/k/v projections.

    Each projection maps ``dim`` input features to ``head * dim`` features,
    so after the heads are split every head works on ``dim`` channels. The
    per-head outputs are concatenated and projected back to ``dim``.

    NOTE(review): the logit scale is ``(dim // head) ** -0.5`` although each
    head operates on ``dim`` channels -- confirm this is intentional.

    Args:
        dim: Size of the last axis of the input and of the output.
        head: Number of attention heads.
    """

    def __init__(self, dim, head=2):
        super().__init__()

        self.head = head
        self.scale = (dim // head) ** -0.5

        inner_dim = head * dim
        self.to_q = nn.Linear(dim, inner_dim)
        self.to_k = nn.Linear(dim, inner_dim)
        self.to_v = nn.Linear(dim, inner_dim)
        self.merge = nn.Linear(inner_dim, dim)

    def _split_heads(self, t):
        """Fold the head axis into the batch axis: (b, n, h*c) -> (b*h, n, c)."""
        return rearrange(t, 'b n (h c) -> (b h) n c', h=self.head)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (batch, tokens, dim).

        Returns a tensor with the same shape as ``x``.
        """
        queries = self._split_heads(self.to_q(x))
        keys = self._split_heads(self.to_k(x))
        values = self._split_heads(self.to_v(x))

        # Scaled pairwise similarity between every query and every key.
        scores = einsum('b i d, b j d -> b i j', queries, keys) * self.scale
        weights = scores.softmax(dim=-1)

        # Weighted sum of the values, then move the heads back next to channels.
        mixed = einsum('b i j, b j d -> b i d', weights, values)
        mixed = rearrange(mixed, '(b h) n c -> b n (h c)', h=self.head, n=x.shape[-2])

        return self.merge(mixed)

class Attention(nn.Module):
    """Multi-head self-attention using one fused q/k/v projection.

    The input channels are divided evenly between the heads, so ``dim`` must
    be a multiple of ``num_heads``. No dropout is applied.

    Args:
        dim: Channel size of the input and output tokens.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused q/k/v projection has a bias term.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        # Scale logits by 1/sqrt(per-head channel count).
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """Attend over the token axis of ``x`` with shape (B, N, C).

        Returns a tensor of shape (B, N, C).
        """
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads

        # (B, N, 3*C) -> (3, B, heads, N, head_dim), then split into q, k, v.
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = scores.softmax(dim=-1)

        # (B, heads, N, head_dim) -> (B, N, C): concatenate the heads again.
        merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj(merged)

In [11]:
class Mlp(nn.Module):
    """Two-layer perceptron: Linear -> activation -> Linear.

    Args:
        in_features: Size of the last axis of the input.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features`` when not given (or falsy).
        out_features: Size of the last axis of the output; falls back to
            ``in_features`` when not given (or falsy).
        act_layer: Zero-argument callable that builds the activation module.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU):
        super().__init__()
        hidden_width = hidden_features or in_features
        out_width = out_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, out_width)

    def forward(self, x):
        """Run ``x`` through fc1, the activation and fc2, in that order."""
        return self.fc2(self.act(self.fc1(x)))

In [12]:
class Block(nn.Module):
    """One encoder layer: normalised attention and normalised MLP, each residual.

    Args:
        dim: Channel size of the tokens.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``
            (truncated to an int).
        qkv_bias: Whether the attention q/k/v projection has a bias.
        act_layer: Activation constructor forwarded to the MLP.
        norm_layer: Constructor called with ``dim`` to build each norm.
    """

    def __init__(
            self, dim, num_heads, mlp_ratio=4., qkv_bias=False,
            act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        hidden_width = int(dim * mlp_ratio)

        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=hidden_width, act_layer=act_layer)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        # Normalisation is applied to the branch input, not to the residual path.
        after_attn = x + self.attn(self.norm1(x))
        return after_attn + self.mlp(self.norm2(after_attn))

In [13]:
class Transformer(nn.Module):
    """A stack of ``num_blocks`` identically configured ``Block`` layers.

    Args:
        dim: Channel size of the tokens.
        num_heads: Number of attention heads in every block.
        num_blocks: How many blocks to stack.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        qkv_bias: Whether the attention q/k/v projections have a bias.
        act_layer: Activation constructor used inside each block's MLP.
        norm_layer: Normalisation constructor used inside each block.
    """

    def __init__(self, dim, num_heads, num_blocks, mlp_ratio=4., qkv_bias=False,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        # Keep the configuration around for inspection.
        self.dim = dim
        self.num_heads = num_heads
        self.num_blocks = num_blocks
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias

        self.blocks = nn.ModuleList()
        for _ in range(num_blocks):
            layer = Block(dim, num_heads, mlp_ratio, qkv_bias, act_layer, norm_layer)
            self.blocks.append(layer)

    def forward(self, x):
        """Feed ``x`` through every block in order and return the result."""
        out = x
        for layer in self.blocks:
            out = layer(out)
        return out

In [14]:
#model = Transformer(dim=512, num_heads=8, num_blocks=6, mlp_ratio=4., qkv_bias=False, 
#                    act_layer=nn.GELU, norm_layer=nn.LayerNorm)
#print(model)

def init_weights_vit_timm(module: nn.Module, name: str = ''):
    """Initialise one module in the style of timm's ViT init hook.

    ``nn.Linear`` modules get Kaiming-uniform weights (``a=sqrt(5)``) and a
    zeroed bias. Any other module that exposes an ``init_weights`` method has
    that method called. Everything else is left untouched.

    Args:
        module: The module to initialise in place.
        name: Unused; kept so the function keeps its ``(module, name)``
            signature.
    """
    if isinstance(module, nn.Linear):
        # The initialiser works on a tensor, so hand it the weight, not the module.
        nn.init.kaiming_uniform_(module.weight, a=math.sqrt(5))
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, 'init_weights'):
        module.init_weights()

In [17]:

class LitTransformer(LightningModule):
    """Lightning wrapper: patch embedding -> Conv1d mixing -> Transformer -> classifier.

    Args:
        num_classes: Number of output logits.
        lr: Learning rate for Adam.
        max_epochs: ``T_max`` of the cosine learning-rate schedule.
        depth: Number of Transformer blocks.
        embed_dim: Channel size of the embedded tokens.
        head: Number of attention heads per block.
        patch_dim: Size of the last axis of the input (flattened patch).
        seqlen: Number of tokens per sample; sizes the final linear layer.
        **kwargs: Accepted and not used by this constructor.
    """

    def __init__(self, num_classes=10, lr=0.001, max_epochs=30, depth=12, embed_dim=64,
                 head=4, patch_dim=192, seqlen=16, **kwargs):
        super().__init__()
        self.save_hyperparameters()

        self.encoder = Transformer(dim=embed_dim, num_heads=head, num_blocks=depth,
                                   mlp_ratio=4., qkv_bias=False,
                                   act_layer=nn.GELU, norm_layer=nn.LayerNorm)
        self.embed = torch.nn.Linear(patch_dim, embed_dim)
        self.merge = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=1, bias=False)
        # The classifier sees all tokens of a sample flattened into one vector.
        self.fc = nn.Linear(seqlen * embed_dim, num_classes)
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        """Map patches of shape (b, n, patch_dim) to logits of shape (b, num_classes)."""
        tokens = self.embed(x)

        # Conv1d convolves over the last axis, so put channels before tokens
        # for the convolution and swap back afterwards.
        mixed = self.merge(rearrange(tokens, 'b n c -> b c n'))
        mixed = rearrange(mixed, 'b c n -> b n c')

        encoded = self.encoder(mixed)
        return self.fc(encoded.flatten(start_dim=1))

    def configure_optimizers(self):
        """Return Adam together with a cosine-annealing schedule."""
        adam = Adam(self.parameters(), lr=self.hparams.lr)
        # Cosine annealing over max_epochs scheduler steps.
        cosine = CosineAnnealingLR(adam, T_max=self.hparams.max_epochs)
        return [adam], [cosine]

    def training_step(self, batch, batch_idx):
        """Return the cross-entropy loss for one training batch."""
        inputs, targets = batch
        return self.loss(self(inputs), targets)

    def test_step(self, batch, batch_idx):
        """Return the logits, loss and accuracy for one evaluation batch."""
        inputs, targets = batch
        logits = self(inputs)
        batch_loss = self.loss(logits, targets)
        batch_acc = accuracy(logits, targets)
        return {"y_hat": logits, "test_loss": batch_loss, "test_acc": batch_acc}

    def test_epoch_end(self, outputs):
        """Log the mean loss and the mean accuracy (as a percentage) over the epoch."""
        losses = [step["test_loss"] for step in outputs]
        accs = [step["test_acc"] for step in outputs]
        mean_loss = torch.stack(losses).mean()
        mean_acc = torch.stack(accs).mean()
        self.log("test_loss", mean_loss, on_epoch=True, prog_bar=True)
        self.log("test_acc", mean_acc*100., on_epoch=True, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        """Validation reuses the test step unchanged."""
        return self.test_step(batch, batch_idx)

    def validation_epoch_end(self, outputs):
        """Validation reuses the test epoch summary (and its metric names)."""
        return self.test_epoch_end(outputs)


# a lightning data module for cifar 10 dataset
class LitCifar10(LightningDataModule):
    """CIFAR-10 data module that serves each image as a sequence of patches.

    Args:
        batch_size: Samples per batch.
        num_workers: Worker processes used by each DataLoader.
        patch_num: Number of patches along each image side; every batch has
            ``patch_num * patch_num`` tokens per image.
        **kwargs: Accepted and not used.
    """

    def __init__(self, batch_size=32, num_workers=32, patch_num=4, **kwargs):
        super().__init__()
        self.batch_size = batch_size
        self.patch_num = patch_num
        self.num_workers = num_workers

    def _build_datasets(self, download):
        """Create the train and test CIFAR-10 datasets under ``./data``."""
        to_tensor = torchvision.transforms.ToTensor()
        self.train_set = torchvision.datasets.cifar.CIFAR10(root='./data', train=True,
                                                            download=download, transform=to_tensor)
        self.test_set = torchvision.datasets.cifar.CIFAR10(root='./data', train=False,
                                                           download=download, transform=to_tensor)

    def prepare_data(self):
        """Download CIFAR-10 if needed and build the datasets.

        The datasets are still assigned here so that code calling
        ``prepare_data()`` directly and then asking for a dataloader keeps
        working.
        """
        self._build_datasets(download=True)

    def setup(self, stage=None):
        """Make sure the datasets exist on this process.

        Lightning documents that ``prepare_data`` is not run on every
        process in a distributed run, so attributes assigned there can be
        missing on other ranks. Build them here (without downloading) when
        they are missing.
        """
        if not (hasattr(self, 'train_set') and hasattr(self, 'test_set')):
            self._build_datasets(download=False)

    def collate_fn(self, batch):
        """Stack samples and cut every image into flattened patches.

        Args:
            batch: Sequence of ``(image, label)`` pairs.

        Returns:
            ``(x, y)`` where ``x`` has shape (b, patch_num**2, c*h*w) with
            ``h``/``w`` the patch height/width, and ``y`` is a LongTensor of
            labels.
        """
        x, y = zip(*batch)
        x = torch.stack(x, dim=0)
        y = torch.LongTensor(y)
        x = rearrange(x, 'b c (p1 h) (p2 w) -> b (p1 p2) (c h w)', p1=self.patch_num, p2=self.patch_num)
        return x, y

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return torch.utils.data.DataLoader(self.train_set, batch_size=self.batch_size,
                                           shuffle=True, collate_fn=self.collate_fn,
                                           num_workers=self.num_workers)

    def test_dataloader(self):
        """Unshuffled loader over the test split."""
        return torch.utils.data.DataLoader(self.test_set, batch_size=self.batch_size,
                                           shuffle=False, collate_fn=self.collate_fn,
                                           num_workers=self.num_workers)

    def val_dataloader(self):
        """Validation uses the test split."""
        return self.test_dataloader()

class LitDataModule(LightningDataModule):
    """Data module backed by random tensors instead of a real dataset.

    Args:
        batch_size: Number of random samples generated per split and the
            batch size of every loader.
        **kwargs: Accepted and not used.
    """

    def __init__(self, batch_size=32, **kwargs):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage=None):
        """Generate one batch worth of random (3, 224, 224) samples per split."""
        sample_shape = (self.batch_size, 3, 224, 224)
        self.train_dataset = torch.rand(*sample_shape)
        self.val_dataset = torch.rand(*sample_shape)

    def _shuffled_loader(self, dataset):
        """Wrap ``dataset`` in a shuffling DataLoader with the configured batch size."""
        return torch.utils.data.DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

    def train_dataloader(self):
        """Loader over the random training tensor."""
        return self._shuffled_loader(self.train_dataset)

    def val_dataloader(self):
        """Loader over the random validation tensor."""
        return self._shuffled_loader(self.val_dataset)

    def test_dataloader(self):
        """Testing reuses the validation tensor."""
        return self._shuffled_loader(self.val_dataset)

def get_args(argv=None):
    """Build the experiment configuration from command-line style options.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) an empty list is parsed, so every option takes its
            default and the process's real command line is ignored.

    Returns:
        argparse.Namespace with ``depth``, ``embed_dim``, ``kernel_size``,
        ``num_heads``, ``patch_num``, ``batch_size``, ``max_epochs``, ``lr``,
        ``accelerator``, ``devices``, ``dataset`` and ``num_workers``.
    """
    # 66.1% cifar10 81.3K (fuse + encoder)
    parser = ArgumentParser(description='PyTorch Transformer')
    parser.add_argument('--depth', type=int, default=12, help='depth')
    parser.add_argument('--embed_dim', type=int, default=64, help='embedding dimension')
    parser.add_argument('--kernel_size', type=int, default=3, help='kernel size')
    parser.add_argument('--num_heads', type=int, default=4, help='num_heads')

    parser.add_argument('--patch_num', type=int, default=16, help='patch_num')

    # %(default)s lets argparse print the real default instead of a stale number.
    parser.add_argument('--batch_size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: %(default)s)')
    parser.add_argument('--max-epochs', type=int, default=30, metavar='N',
                        help='number of epochs to train (default: %(default)s)')
    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (default: %(default)s)')

    parser.add_argument('--accelerator', default='gpu', type=str, metavar='N')
    parser.add_argument('--devices', default=1, type=int, metavar='N')
    parser.add_argument('--dataset', default='cifar10', type=str, metavar='N')
    parser.add_argument('--num_workers', default=4, type=int, metavar='N')

    # Parse an explicit list; never fall back to sys.argv.
    args = parser.parse_args([] if argv is None else list(argv))
    return args


if __name__ == "__main__":
    args = get_args()

    datamodule = LitCifar10(batch_size=args.batch_size,
                            patch_num=args.patch_num, num_workers=args.num_workers * args.devices)
    datamodule.prepare_data()

    # Pull a single batch to read the token count and per-token feature size
    # the model has to be built for. The builtin next() is the supported way
    # to advance an iterator.
    data = next(iter(datamodule.train_dataloader()))
    patch_dim = data[0].shape[-1]
    seqlen = data[0].shape[-2]

    # Pass max_epochs under the name the model declares, so the scheduler
    # length follows the command-line value instead of the model's default.
    model = LitTransformer(num_classes=10, lr=args.lr, max_epochs=args.max_epochs,
                           depth=args.depth, embed_dim=args.embed_dim, head=args.num_heads,
                           patch_dim=patch_dim, seqlen=seqlen,)

    # Use 16-bit precision only when running on a GPU.
    trainer = Trainer(accelerator=args.accelerator, devices=args.devices,
                      max_epochs=args.max_epochs, precision=16 if args.accelerator == 'gpu' else 32,)
    trainer.fit(model, datamodule=datamodule)
    #trainer.test(model, datamodule=datamodule)

Files already downloaded and verified
Files already downloaded and verified


Using 16bit native Automatic Mixed Precision (AMP)
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


Missing logger folder: /home/rowel/github/roatienza/Deep-Learning-Experiments/versions/2022/transformer/python/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name    | Type             | Params
---------------------------------------------
0 | encoder | Transformer      | 597 K 
1 | embed   | Linear           | 832   
2 | merge   | Conv1d           | 12.3 K
3 | fc      | Linear           | 163 K 
4 | loss    | CrossEntropyLoss | 0     
---------------------------------------------
774 K     Trainable params
0         Non-trainable params
774 K     Total params
1.549     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

RuntimeError: Given groups=1, weight of size [64, 64, 3], expected input[128, 256, 64] to have 64 channels, but got 256 channels instead