## CNN+Vit 1 with Pytorch

#### Adding Convolutions - 1: On the patches
Instead of having a Linear Projection from the patches being fed into the Transformer encoder, we can convolve over the patches then feed them into the encoder. We can simulate this by having both kernel size and stride be equal to the size of the patches.

In [35]:
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision import transforms
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from tqdm import tqdm
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

In [36]:
TRAIN = r'C:\Users\roysu\Desktop\Project\test_folder\val'
VAL = r'C:\Users\roysu\Desktop\Project\test_folder\train'

# TRAIN = "/home/udai/Desktop/Data/train_1"
# VAL = "/home/udai/Desktop/Data/test_1"

INPUT_HEIGHT = 256
INPUT_WIDTH = 256

BATCH_SIZE = 64
VAL_SPLIT = 0.1
EPOCH=25

In [37]:
resize = transforms.Resize(size=(INPUT_HEIGHT,INPUT_WIDTH))
hFlip = transforms.RandomHorizontalFlip(p=0.25)
vFlip = transforms.RandomVerticalFlip(p=0.25)
rotate = transforms.RandomRotation(degrees=15)

trainTransforms = transforms.Compose([resize, hFlip, vFlip, rotate,transforms.ToTensor()])
valTransforms = transforms.Compose([resize, transforms.ToTensor()])

trainDataset = ImageFolder(root=TRAIN,transform=trainTransforms,)
valDataset = ImageFolder(root=VAL, transform=valTransforms)

print("Training dataset contains {} samples...".format(len(trainDataset)))
print("Validation dataset contains {} samples...".format(len(valDataset)))

trainloader = DataLoader(trainDataset,batch_size=BATCH_SIZE, shuffle=True)
testloader = DataLoader(valDataset, batch_size=BATCH_SIZE)


Training dataset contains 19033 samples...
Validation dataset contains 76601 samples...


In [38]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
print(device)
criterion = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler(enabled=False)

cpu


In [39]:
def train(net, epoch):
    print('\nEpoch: %d' % epoch)
    optimizer = optim.Adam(net.parameters(), lr=1e-4)
    net.train()
    train_loss = 0
    correct = 0
    total = 0
    print("Start Training")
    for batch_idx, (inputs, targets) in enumerate(tqdm(trainloader)):
        inputs, targets = inputs.to(device), targets.to(device)
        # Train with amp
        with torch.cuda.amp.autocast(enabled=False):
            outputs = net(inputs)
            loss = criterion(outputs, targets)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

#         if batch_idx == len(trainloader) - 1:
#               print('train', batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
#                 % (train_loss/(batch_idx+1), 100.*correct/total, correct, total))

    return round(train_loss/(batch_idx+1),2),round(100.*correct/total,2)


def test(net, epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(tqdm(testloader)):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

#             if batch_idx == len(testloader) - 1:
#                   print('test', batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
#                     % (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
    acc = round(100.*correct/total,2)

    return round(test_loss,2), acc

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

In [40]:
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
          nn.Linear(dim, hidden_dim),
          nn.GELU(),
          nn.Dropout(dropout),
          nn.Linear(hidden_dim, dim),
          nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

In [41]:
class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
          nn.Linear(inner_dim, dim),
          nn.Dropout(dropout)
          ) if project_out else nn.Identity()
    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

In [42]:
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
                ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
            return x

In [43]:
class ViTConv1(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width) # 64
        patch_dim = channels * patch_height * patch_width # 48
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
          nn.Conv2d(3, 3, kernel_size=patch_height, stride=patch_height), # -> 512, 3, 8, 8
          Rearrange('b c h w -> b (h w) c'), # -> 512 64 3
          nn.Linear(3, dim),
          )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)

In [44]:
net2 = ViTConv1(
    image_size = INPUT_HEIGHT,
    patch_size = 16,
    num_classes = 2,
    dim = 512,
    depth = 6,
    heads = 8,
    mlp_dim = 512,
    dropout = 0.1,
    emb_dropout = 0.1
)
net2 = torch.nn.DataParallel(net2)

In [45]:
from torchsummary import summary
summary(net2,(3,256,256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 3, 16, 16]           2,307
         Rearrange-2               [-1, 256, 3]               0
            Linear-3             [-1, 256, 512]           2,048
           Dropout-4             [-1, 257, 512]               0
         LayerNorm-5             [-1, 257, 512]           1,024
            Linear-6            [-1, 257, 1536]         786,432
           Softmax-7          [-1, 8, 257, 257]               0
            Linear-8             [-1, 257, 512]         262,656
           Dropout-9             [-1, 257, 512]               0
        Attention-10             [-1, 257, 512]               0
          PreNorm-11             [-1, 257, 512]               0
        LayerNorm-12             [-1, 257, 512]           1,024
           Linear-13             [-1, 257, 512]         262,656
             GELU-14             [-1, 2

In [46]:
import datetime
import os
time=datetime.datetime.now()
output_folder="CNN+VIT_Output_at_{}-{}-{}_H{}M{}".format(time.day,time.month,time.year,time.hour,time.minute)

os.mkdir(output_folder)
cwd=os.getcwd()


for i in range(EPOCH):
    t_loss,t_acc=train(net2, i)
    v_loss,v_acc=test(net2, i)
    print(f"Epoch : {i} Train_loss : {t_loss} Train_acc : {t_acc} Test_loss : {v_loss} Test_acc {v_acc} ")
    with open(file=os.path.join(cwd,output_folder,"model_logs.txt"), mode="a") as log:
        log.write(f"Epoch : {i} Train_loss : {t_loss} Train_acc : {t_acc} Test_loss : {v_loss} Test_acc {v_acc} \n")
    torch.save(net,os.path.join(cwd,output_folder,f"model_at_epoch_{i}.model"))


Epoch: 0
Start Training


  6%|████▉                                                                            | 18/298 [02:55<45:37,  9.78s/it]


KeyboardInterrupt: 