In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import random
import torchvision.transforms.functional as TF

In [2]:
from torch.utils.data.dataloader import default_collate

In [3]:
def randomSizeCollate(batch, sizes=(28, 30, 34, 36)):
    """Collate (image, target) pairs, resizing the whole batch to one random side length.

    A single size is drawn per call, so every image in the batch ends up with the
    same shape (required for stacking) while successive batches vary in resolution.
    Augmentation (padded random crop, horizontal flip) and normalisation are applied
    here, so the dataset feeding this function should yield raw images (presumably
    PIL images, i.e. a dataset built with ``transform=None``), not tensors.

    Args:
        batch: sequence of ``(image, target)`` pairs produced by the dataset.
        sizes: candidate output side lengths; one is picked per batch with
            ``random.choice``. Defaults to the previously hard-coded choices.

    Returns:
        ``default_collate`` applied to the transformed pairs (batched images and
        batched targets).
    """
    images, targets = zip(*batch)

    size = random.choice(sizes)

    # Crop to 32x32 first (same crop as the regular training transform), then
    # rescale the square crop to the side length chosen for this batch.
    batch_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.Resize(size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    images = [batch_transform(img) for img in images]
    return default_collate(list(zip(images, targets)))
    

In [4]:
# Toggle for the multi-resolution experiment. When True, the training set yields raw
# images and `randomSizeCollate` applies augmentation plus a random per-batch resize.
# The dataset must then NOT apply `train_transform` itself: that transform already
# converts images to tensors (ToTensor expects PIL/ndarray input), and the collate
# function would try to convert them a second time.
USE_RANDOM_SIZE_COLLATE = False
NUM_WORKERS = 16  # DataLoader worker processes; tune to the machine's core count
batch_size = 64

# Evaluation: convert to tensor and map each channel from [0, 1] to [-1, 1].
test_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Training: padded random crop + horizontal flip, then the same normalisation.
train_transform = transforms.Compose(
    [transforms.RandomCrop(32, padding=4),
     transforms.RandomHorizontalFlip(),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True,
    transform=None if USE_RANDOM_SIZE_COLLATE else train_transform)
# collate_fn=None makes DataLoader fall back to its default collation.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=NUM_WORKERS,
    collate_fn=randomSizeCollate if USE_RANDOM_SIZE_COLLATE else None)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=NUM_WORKERS)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [5]:
import torch

import torch.nn as nn

from torch import Tensor
from typing import Type, Any, Callable, Union, List, Optional, Tuple

In [6]:
class ConvNormAct(nn.Module):
    """Conv2d followed by InstanceNorm2d and an ELU activation.

    Args:
        inPlanes: number of input channels.
        outPlanes: number of output channels.
        kernel: square kernel size for the convolution.
        padding: convolution padding (default 0).
        dilation: convolution dilation (default 1).
        groups: convolution groups (default 1).
    """

    def __init__(self,
                 inPlanes: int,
                 outPlanes: int,
                 kernel: int,
                 padding: int = 0,
                 dilation: int = 1,
                 groups: int = 1
                 ) -> None:
        super().__init__()

        conv_options = dict(padding=padding, dilation=dilation, groups=groups)
        self.conv = nn.Conv2d(inPlanes, outPlanes, kernel, **conv_options)
        # Default InstanceNorm2d: per-sample, per-channel normalisation.
        self.norm = nn.InstanceNorm2d(outPlanes)
        self.act = nn.ELU()

    def forward(self, x: Tensor) -> Tensor:
        """Apply conv -> norm -> ELU."""
        return self.act(self.norm(self.conv(x)))
    
class ConvNorm(nn.Module):
    """Conv2d followed by InstanceNorm2d, with no activation.

    Args:
        inPlanes: number of input channels.
        outPlanes: number of output channels.
        kernel: square kernel size for the convolution.
        padding: convolution padding (default 0).
        dilation: convolution dilation (default 1).
        groups: convolution groups (default 1).
    """

    def __init__(self,
                 inPlanes: int,
                 outPlanes: int,
                 kernel: int,
                 padding: int = 0,
                 dilation: int = 1,
                 groups: int = 1
                 ) -> None:
        super().__init__()

        conv_options = dict(padding=padding, dilation=dilation, groups=groups)
        self.conv = nn.Conv2d(inPlanes, outPlanes, kernel, **conv_options)
        # Default InstanceNorm2d: per-sample, per-channel normalisation.
        self.norm = nn.InstanceNorm2d(outPlanes)

    def forward(self, x: Tensor) -> Tensor:
        """Apply conv -> norm."""
        return self.norm(self.conv(x))
    
    
    
class ConvAttn(nn.Module):
    """Channel-mixing attention with queries and keys from dilated convolutions.

    Q (``qdim`` channels) and K (``inplanes`` channels) are computed from the input
    and flattened over their spatial positions. Their dot product over those
    positions gives a ``(qdim, inplanes)`` weight matrix. That matrix is
    softmax-normalised over the ``inplanes`` axis and used to mix the channels of
    the input itself, which plays the role of V.

    Args:
        inplanes: channel count of the input (and of K).
        qdim: channel count of Q and of the output.
        groups: group count for the K convolution; must divide ``inplanes``.
    """

    def __init__(self,
                 inplanes: int,
                 qdim: int,
                 groups: int
                 ) -> None:
        super(ConvAttn, self).__init__()

        # 5x5 kernels, dilation 2, no padding: each spatial dim of Q and K is 8
        # smaller than V's. That is fine because Q @ K^T contracts over the spatial
        # positions Q and K share. TODO: evaluate whether padding Q/K helps.
        self.Q = ConvNorm(inplanes, qdim, 5, dilation=2)
        self.K = ConvNorm(inplanes, inplanes, 5, dilation=2, groups=groups)
        self.qdim = qdim

    def forward(self, V: Tensor) -> Tensor:
        """Mix the channels of ``V`` with attention weights derived from ``V``.

        Args:
            V: 4-D input ``(B, inplanes, H, W)``.

        Returns:
            Tensor of shape ``(B, qdim, H, W)``.
        """
        K = self.K(V)
        Q = self.Q(V)

        Q = torch.flatten(Q, 2, 3)                          # (B, qdim, P)
        K = torch.transpose(torch.flatten(K, 2, 3), 1, 2)   # (B, P, inplanes)

        # Scaled dot product. The contraction runs over the P spatial positions,
        # so scale by sqrt(P). A Python float avoids allocating a device tensor
        # (and looking up the parameters' device) on every forward pass.
        scale = K.shape[-2] ** 0.5
        attn = torch.matmul(Q, K) / scale
        attn = torch.softmax(attn, dim=2)                    # (B, qdim, inplanes)

        batch, _, height, width = V.shape
        out = torch.matmul(attn, torch.flatten(V, 2, 3))     # (B, qdim, H*W)
        return out.view(batch, self.qdim, height, width)
        

class StaticValueAttn(nn.Module):
    """Attention over a learned, input-independent bank of value vectors.

    The caller supplies queries and keys; the values are a parameter of this module.

    Args:
        nValues: number of learned value vectors; must equal the channel count of K.
        dimV: length of each value vector, i.e. the output feature size.
    """

    def __init__(self,
                 nValues: int,
                 dimV: int,
                 ) -> None:
        super(StaticValueAttn, self).__init__()

        # Shape (1, nValues, dimV); the leading 1 broadcasts over the batch in
        # forward(). torch.rand initialises uniformly in [0, 1). (A kaiming_normal_
        # init with mode='fan_out' was tried earlier and is currently disabled.)
        self.values = nn.Parameter(torch.rand(1, nValues, dimV))
        # Not used by forward(); kept so the module's attributes stay unchanged.
        self.act = nn.ELU()

    def forward(self,
                Q: Tensor,
                K: Tensor
                ) -> Tensor:
        """Attend from Q over the learned values, using K as keys.

        Args:
            Q: 4-D ``(B, Cq, H, W)``.
            K: 4-D ``(B, nValues, H, W)`` with the same spatial size as Q.

        Returns:
            Tensor of shape ``(B, Cq, dimV)``.
        """
        Q = torch.flatten(Q, 2, 3)                           # (B, Cq, P)
        K = torch.transpose(torch.flatten(K, 2, 3), 1, 2)    # (B, P, nValues)

        # The contraction runs over the P spatial positions, so scale by sqrt(P).
        # A Python float avoids building a device tensor on every call.
        scale = K.shape[-2] ** 0.5
        attn = torch.softmax(torch.matmul(Q, K) / scale, dim=2)  # (B, Cq, nValues)
        return torch.matmul(attn, self.values)

In [7]:
class RAtNet(nn.Module):
    """Residual attention network: stem conv, residual conv+attention blocks, attention head.

    Args:
        blocks: list of ``(in_channels, out_channels)`` pairs. ``blocks[0]``
            configures the stem convolution (image channels -> residual width W).
            Every later pair ``(W, C)`` adds one residual block: a ConvNormAct
            widening W -> C followed by a ConvAttn mapping back to W channels so it
            can be added to the running features. Every later ``in_channels`` must
            therefore equal W (``blocks[0][1]``).
        num_classes: output size of the final linear layer (default 10).
    """

    def __init__(self,
                 blocks: List[Tuple[int, int]],
                 num_classes: int = 10
                 ) -> None:
        super(RAtNet, self).__init__()

        stem_in, width = blocks[0]
        # 5x5 kernels with padding 2 keep the spatial size unchanged.
        conv = [ConvNormAct(stem_in, width, 5, padding=2)]
        attn = []
        for in_ch, out_ch in blocks[1:]:
            conv.append(ConvNormAct(in_ch, out_ch, 5, padding=2))
            attn.append(ConvAttn(out_ch, in_ch, out_ch // 4))

        self.conv = nn.ModuleList(conv)
        self.attn = nn.ModuleList(attn)

        # Head over the residual features (always `width` channels). This equals
        # blocks[-1][0] for any valid multi-block config, and is also correct when
        # `blocks` holds only the stem.
        self.Q = ConvNorm(width, 1, 5, dilation=2)
        self.K = ConvNorm(width, 128, 5, dilation=2)

        self.valueAttn = StaticValueAttn(128, 128)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x: Tensor) -> Tensor:
        """Map a 4-D image batch to ``(B, num_classes)`` logits."""
        x = self.conv[0](x)

        for conv, attn in zip(self.conv[1:], self.attn):
            x = x + attn(conv(x))

        fc = self.valueAttn(self.Q(x), self.K(x))   # (B, 1, 128)
        fc = torch.squeeze(fc, dim=1)              # (B, 128)

        return self.fc(fc)
    

In [8]:
from torchvision.models.resnet import *
from torchvision.models.resnet import BasicBlock, Bottleneck

In [9]:
class ResAttClass(ResNet):
    """torchvision ``ResNet`` whose ``avgpool``/``fc`` classifier is replaced by an attention head.

    ``Q`` reduces the final feature map to one channel and ``K`` to 128 channels.
    ``StaticValueAttn`` then attends over 128 learned value vectors whose length is
    the number of classes, producing one logit vector per image.

    Accepts the same arguments as ``ResNet``. ``num_classes`` sizes the head; when it
    is not given, the head keeps this class's historical default of 10 outputs
    rather than ResNet's default of 1000.
    """

    def __init__(self, *args, **kwargs) -> None:
        super(ResAttClass, self).__init__(*args, **kwargs)

        # Read the head dimensions from the classifier ResNet just built, before
        # discarding it. in_features is layer4's output width (512 for BasicBlock);
        # out_features is the num_classes ResNet received.
        feat_channels = self.fc.in_features
        # ResNet(block, layers, num_classes, ...): num_classes may be positional.
        num_classes_given = "num_classes" in kwargs or len(args) > 2
        num_classes = self.fc.out_features if num_classes_given else 10

        del self.avgpool
        del self.fc

        self.Q = ConvNorm(feat_channels, 1, 3, padding=1)
        self.K = ConvNorm(feat_channels, 128, 3, padding=1)

        self.valueAttn = StaticValueAttn(128, num_classes)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        Q = self.Q(x)        # (B, 1, h, w)
        K = self.K(x)        # (B, 128, h, w)

        x = self.valueAttn(Q, K)   # (B, 1, num_classes)

        # Sum over the singleton query axis -> (B, num_classes).
        return torch.sum(x, dim=1)
        

In [10]:
def _resAttClass(
    arch: str,
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    pretrained: bool,
    progress: bool,
    **kwargs: Any
) -> ResNet:
    """Build a ``ResAttClass``, optionally initialising its backbone from ImageNet weights.

    Args:
        arch: torchvision architecture key (e.g. ``'resnet18'``) used to look up the
            pretrained checkpoint URL.
        block: residual block class passed to ``ResNet``.
        layers: number of blocks per stage passed to ``ResNet``.
        pretrained: if True, load the ImageNet backbone weights; the attention head
            (``Q``, ``K``, ``valueAttn``) keeps its fresh initialisation.
        progress: show a download progress bar.
        **kwargs: forwarded to ``ResAttClass``/``ResNet``.

    Raises:
        RuntimeError: if the checkpoint's backbone keys do not match the model.
    """
    model = ResAttClass(block, layers, **kwargs)
    if pretrained:
        # Neither helper is in torchvision.models.resnet.__all__, so the module-level
        # star import does not provide them; import them explicitly.
        # NOTE(review): `model_urls` only exists in older torchvision releases.
        from torch.hub import load_state_dict_from_url
        from torchvision.models.resnet import model_urls

        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        # ResAttClass deletes the ImageNet classifier (`fc`) and adds a new head,
        # so load only the backbone and tolerate exactly the head's missing keys.
        backbone = {k: v for k, v in state_dict.items() if not k.startswith("fc.")}
        result = model.load_state_dict(backbone, strict=False)
        head_prefixes = ("Q.", "K.", "valueAttn.")
        missing = [k for k in result.missing_keys if not k.startswith(head_prefixes)]
        if missing or result.unexpected_keys:
            raise RuntimeError(
                f"pretrained '{arch}' weights do not match the backbone: "
                f"missing={missing}, unexpected={list(result.unexpected_keys)}")

    return model

In [11]:
def resnet18CIFAR10(num_classes: int = 10):
    """Build an untrained torchvision ResNet-18 adapted to small (CIFAR-sized) images.

    The stem is replaced by a 3x3, stride-1 convolution and the max-pool is removed,
    so early layers keep the full input resolution.

    Args:
        num_classes: size of the final classification layer (default 10 for CIFAR-10).

    Returns:
        The modified ``torchvision`` ResNet-18 model.
    """
    model = torchvision.models.resnet18(pretrained=False, num_classes=num_classes)
    # The stock stem downsamples aggressively (it presumably targets larger,
    # ImageNet-sized inputs); use a resolution-preserving 3x3 stem instead.
    model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    model.maxpool = nn.Identity()

    return model

In [12]:
import pytorch_lightning as pl
import torchvision
import torch.nn as nn
import torch.nn.functional as F

from torchmetrics.functional import accuracy
from torch.optim import Adam

class BoilerPlate(pl.LightningModule):
    """Lightning wrapper training a classifier with NLL loss on log-softmax outputs.

    Args:
        train_l: loader returned by ``train_dataloader``.
        val_l: loader returned by ``val_dataloader``.
        blocks: block specification kept for backward compatibility; it is not used
            by the default backbone. Pass ``model=RAtNet(blocks)`` to train the
            attention network instead.
        model: optional backbone mapping images to per-class logits; defaults to
            ``resnet18CIFAR10()``.
        lr: Adam learning rate (default 0.02).
    """

    def __init__(self, train_l, val_l, blocks,
                 model: Optional[nn.Module] = None, lr: float = 0.02) -> None:
        super().__init__()

        self.train_l = train_l
        self.val_l = val_l
        self.lr = lr

        self.model = model if model is not None else resnet18CIFAR10()

    def forward(self, x):
        """Return log-probabilities over classes (pairs with ``F.nll_loss``)."""
        out = self.model(x)

        return F.log_softmax(out, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)

        loss = F.nll_loss(logits, y)
        self.log("train_loss", loss)

        return loss

    def evaluate(self, batch, stage=None):
        """Compute loss and accuracy on ``batch``; log them under ``stage`` if given."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def configure_optimizers(self):
        # Only optimise trainable parameters (frozen ones are skipped).
        return Adam([p for p in self.parameters() if p.requires_grad], lr=self.lr, eps=1e-08)

    def train_dataloader(self):
        return self.train_l

    def val_dataloader(self):
        return self.val_l

        

In [13]:
# RAtNet spec: stem (3 -> 16), then two 64-wide and two 128-wide residual blocks
# that each read and write the 16-channel residual stream.
blocks = [(3, 16)] + [(16, 64)] * 2 + [(16, 128)] * 2

In [14]:
# Wrap the train/test loaders in the Lightning module (default backbone: resnet18CIFAR10).
model = BoilerPlate(train_l=trainloader, val_l=testloader, blocks=blocks)

In [16]:
# 30 epochs on one GPU; TensorBoard logs go to lightning_logs/resnet18_cifar10.
trainer = pl.Trainer(
    gpus=1,
    max_epochs=30,
    progress_bar_refresh_rate=10,
    logger=pl.loggers.TensorBoardLogger(
        "lightning_logs/",
        name="resnet18_cifar10",
    ),
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [17]:
# Train `model`; the data comes from its train_dataloader()/val_dataloader() hooks.
trainer.fit(model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [2]

  | Name  | Type   | Params
---------------------------------
0 | model | ResNet | 11.2 M
---------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.696    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [18]:
# Rebuild a fresh, un-shuffled CIFAR-10 test loader for the evaluation cells below.
# This cell previously duplicated the whole data cell, including the training
# pipeline, which evaluation does not use; the training objects defined earlier
# are still in scope.
test_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 64

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=test_transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=16)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


Files already downloaded and verified
Files already downloaded and verified


In [19]:
import numpy as np
from sklearn.metrics import classification_report

In [20]:
# Run the trained model over the test loader, collecting per-batch predictions/labels.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
model.eval()
preds, labels = [], []
# Inference only: torch.no_grad() stops autograd from recording a graph for every
# batch, which would otherwise waste memory and time on gradients never used.
with torch.no_grad():
    for batch in testloader:
        x, y = batch
        x = x.to(device)
        logits = model(x)
        y_pred = torch.argmax(logits, dim=1)

        preds.append(y_pred.cpu().numpy())
        labels.append(y.cpu().numpy())

In [21]:
# Flatten the per-batch arrays into single 1-D vectors. Guarded so that re-running
# this cell does not try to concatenate an already-flattened array (which fails).
if isinstance(preds, list):
    preds = np.concatenate(preds)
if isinstance(labels, list):
    labels = np.concatenate(labels)

In [22]:
# Per-class precision / recall / F1 on the CIFAR-10 test set, labelled with `classes`.
print(classification_report(labels, preds, target_names=classes))

              precision    recall  f1-score   support

       plane       0.90      0.89      0.90      1000
         car       0.97      0.94      0.96      1000
        bird       0.80      0.90      0.85      1000
         cat       0.85      0.73      0.79      1000
        deer       0.88      0.90      0.89      1000
         dog       0.82      0.87      0.84      1000
        frog       0.93      0.91      0.92      1000
       horse       0.92      0.91      0.91      1000
        ship       0.96      0.92      0.94      1000
       truck       0.92      0.96      0.94      1000

    accuracy                           0.89     10000
   macro avg       0.89      0.89      0.89     10000
weighted avg       0.89      0.89      0.89     10000

