In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary

from model_training import DoodleDataset, RealDataset

In [34]:
class ConvNeXtBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv1 = nn.Conv2d(dim, dim, (7, 7), padding=3, groups=dim)
        self.lin1 = nn.Linear(dim, 4 * dim)
        self.lin2 = nn.Linear(4 * dim, dim)
        self.ln = nn.LayerNorm(dim)
        self.gelu = nn.GELU()

    def forward(self, x):
        res_inp = x
        x = self.conv1(x)
        x = x.permute(0, 2, 3, 1)  # NCHW -> NHWC
        x = self.ln(x)
        x = self.lin1(x)
        x = self.lin2(x)
        x = self.gelu(x)
        x = x.permute(0, 3, 1, 2)  # NHWC -> NCHW
        out = x + res_inp

        return out


class ConvNeXt(nn.Module):
    def __init__(self, in_channels, classes, block_dims=[192, 384, 768]):
        super().__init__()
        self.blocks = nn.Sequential(
            nn.Conv2d(in_channels, block_dims[0], kernel_size=2, stride=2),
            ConvNeXtBlock(block_dims[0]),
            nn.Conv2d(block_dims[0], block_dims[1], kernel_size=2, stride=2),
            ConvNeXtBlock(block_dims[1]),
            nn.Conv2d(block_dims[1], block_dims[2], kernel_size=2, stride=2),
            ConvNeXtBlock(block_dims[2]),
        )
        self.block_dims = block_dims
        self.project = nn.Linear(block_dims[-1], classes)

    def forward(self, x, return_feats=False):
        feats = self.blocks(x)
        x = feats.view(-1, self.block_dims[-1], 8*8).mean(2)
        out = self.project(x)

        if return_feats:
            return out, feats
        
        return out

In [35]:
net = ConvNeXt(3, 9)
x = torch.rand(100, 3, 64, 64)
y = net(x)
print (y.shape)

torch.Size([100, 9])


In [36]:
print (summary(net, input_size=(512, 3, 64, 64)))

Layer (type:depth-idx)                   Output Shape              Param #
ConvNeXt                                 --                        --
├─Sequential: 1-1                        [512, 768, 8, 8]          --
│    └─Conv2d: 2-1                       [512, 192, 32, 32]        2,496
│    └─ConvNeXtBlock: 2-2                [512, 192, 32, 32]        --
│    │    └─Conv2d: 3-1                  [512, 192, 32, 32]        9,600
│    │    └─LayerNorm: 3-2               [512, 32, 32, 192]        384
│    │    └─Linear: 3-3                  [512, 32, 32, 768]        148,224
│    │    └─Linear: 3-4                  [512, 32, 32, 192]        147,648
│    │    └─GELU: 3-5                    [512, 32, 32, 192]        --
│    └─Conv2d: 2-3                       [512, 384, 16, 16]        295,296
│    └─ConvNeXtBlock: 2-4                [512, 384, 16, 16]        --
│    │    └─Conv2d: 3-6                  [512, 384, 16, 16]        19,200
│    │    └─LayerNorm: 3-7               [512, 16, 16, 384]

In [43]:
real_train_set = RealDataset(train=True)
real_val_set = RealDataset(train=False)

real_train_loader = torch.utils.data.DataLoader(real_train_set, batch_size=512, shuffle=True)
real_val_loader = torch.utils.data.DataLoader(real_val_set, batch_size=real_val_set.X.shape[0])

In [44]:
print (real_train_set.X.shape)

(11141, 64, 64, 3)


In [45]:
criterion = nn.CrossEntropyLoss()
model = ConvNeXt(3, 9)
optim = torch.optim.AdamW(model.parameters())

In [46]:
model = nn.DataParallel(model).cuda()

In [47]:
def get_accuracy(pred, label):
    pred, label = pred.cpu(), label.cpu()
    return (pred.argmax(1) == label).sum().item() / len(label)

In [49]:
epochs= 20

for epoch in range(epochs):
    total_loss = 0
    model.train()
    count = 0
    for idx, (x, y) in enumerate(real_train_loader):
        count += 1
        optim.zero_grad()
        
        x, y = x.cuda(), y.cuda()
        pred = model(x)
        loss = criterion(pred, y)

        total_loss += loss.detach().cpu().item()
        loss.backward()
        optim.step()
    
    epoch_loss = total_loss / count
    
    val_data = next(iter(real_val_loader))
    with torch.no_grad():
        model.eval()
        x, y = val_data
        val_pred = model(x)
        val_acc = get_accuracy(val_pred, y)
        
    print ("Epoch: {}, Training Loss: {}, Val Accuracy: {}".format(epoch, epoch_loss, val_acc))

Epoch: 0, Training Loss: 2.1540638370947405, Val Accuracy: 0.3816533524818239


RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
  File "/home/rishabh/miniconda3/envs/Rish4243/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home/rishabh/miniconda3/envs/Rish4243/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/tmp/ipykernel_2674731/3524364667.py", line 39, in forward
    feats = self.blocks(x)
  File "/home/rishabh/miniconda3/envs/Rish4243/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/rishabh/miniconda3/envs/Rish4243/lib/python3.7/site-packages/torch/nn/modules/container.py", line 141, in forward
    input = module(input)
  File "/home/rishabh/miniconda3/envs/Rish4243/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/tmp/ipykernel_2674731/3524364667.py", line 15, in forward
    x = self.lin1(x)
  File "/home/rishabh/miniconda3/envs/Rish4243/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/rishabh/miniconda3/envs/Rish4243/lib/python3.7/site-packages/torch/nn/modules/linear.py", line 103, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: CUDA out of memory. Tried to allocate 8.16 GiB (GPU 1; 23.70 GiB total capacity; 4.27 GiB already allocated; 6.77 GiB free; 10.40 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
