In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary
from utils import *
from model_training import DoodleDataset, RealDataset

In [2]:
class ConvNeXtBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv1 = nn.Conv2d(dim, dim, (7, 7), padding=3, groups=dim)
        self.lin1 = nn.Linear(dim, 4 * dim)
        self.lin2 = nn.Linear(4 * dim, dim)
        self.ln = nn.LayerNorm(dim)
        self.gelu = nn.GELU()

    def forward(self, x):
        res_inp = x
        x = self.conv1(x)
        x = x.permute(0, 2, 3, 1)  # NCHW -> NHWC
        x = self.ln(x)
        x = self.lin1(x)
        x = self.lin2(x)
        x = self.gelu(x)
        x = x.permute(0, 3, 1, 2)  # NHWC -> NCHW
        out = x + res_inp

        return out


class ConvNeXt(nn.Module):
    def __init__(self, in_channels, classes, block_dims=[192, 384, 768]):
        super().__init__()
        self.blocks = nn.Sequential(
            nn.Conv2d(in_channels, block_dims[0], kernel_size=2, stride=2),
            ConvNeXtBlock(block_dims[0]),
            nn.Conv2d(block_dims[0], block_dims[1], kernel_size=2, stride=2),
            ConvNeXtBlock(block_dims[1]),
            nn.Conv2d(block_dims[1], block_dims[2], kernel_size=2, stride=2),
            ConvNeXtBlock(block_dims[2]),
        )
        self.block_dims = block_dims
        self.project = nn.Linear(block_dims[-1], classes)

    def forward(self, x, return_feats=False):
        feats = self.blocks(x)
        x = feats.view(-1, self.block_dims[-1], 8*8).mean(2)
        out = self.project(x)

        if return_feats:
            return out, feats
        
        return out

In [3]:
net = ConvNeXt(3, 9)
x = torch.rand(100, 3, 64, 64)
y = net(x)
print (y.shape)

torch.Size([100, 9])


In [4]:
print (summary(net, input_size=(512, 3, 64, 64)))

Layer (type:depth-idx)                   Output Shape              Param #
ConvNeXt                                 --                        --
├─Sequential: 1-1                        [512, 768, 8, 8]          --
│    └─Conv2d: 2-1                       [512, 192, 32, 32]        2,496
│    └─ConvNeXtBlock: 2-2                [512, 192, 32, 32]        --
│    │    └─Conv2d: 3-1                  [512, 192, 32, 32]        9,600
│    │    └─LayerNorm: 3-2               [512, 32, 32, 192]        384
│    │    └─Linear: 3-3                  [512, 32, 32, 768]        148,224
│    │    └─Linear: 3-4                  [512, 32, 32, 192]        147,648
│    │    └─GELU: 3-5                    [512, 32, 32, 192]        --
│    └─Conv2d: 2-3                       [512, 384, 16, 16]        295,296
│    └─ConvNeXtBlock: 2-4                [512, 384, 16, 16]        --
│    │    └─Conv2d: 3-6                  [512, 384, 16, 16]        19,200
│    │    └─LayerNorm: 3-7               [512, 16, 16, 384]

In [11]:
real_train_set = RealDataset(train=True)
real_val_set = RealDataset(train=False)

real_train_loader = torch.utils.data.DataLoader(real_train_set, batch_size=512, shuffle=True)
real_val_loader = torch.utils.data.DataLoader(real_val_set, batch_size=512)

In [12]:
print (real_train_set.X.shape)

(11141, 64, 64, 3)


In [13]:
criterion = nn.CrossEntropyLoss()
model = ConvNeXt(3, 9)
optim = torch.optim.AdamW(model.parameters())

In [14]:
fix_seed(0)  # zero seed by default
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

model = nn.DataParallel(model).cuda()

In [15]:
def get_accuracy(pred, label):
    pred, label = pred.cpu(), label.cpu()
    return (pred.argmax(1) == label).sum().item() / len(label)

In [17]:
epochs= 20

for epoch in range(epochs):
    total_loss = 0
    model.train()
    count = 0
    for idx, (x, y) in enumerate(real_train_loader):
        count += 1
        optim.zero_grad()
        
        x, y = x.cuda(), y.cuda()
        pred = model(x)
        loss = criterion(pred, y)

        total_loss += loss.detach().cpu().item()
        loss.backward()
        optim.step()
    
    epoch_loss = total_loss / count
    
    total_val_acc = 0
    with torch.no_grad():
        model.eval()
        count = 0
        for idx, (x, y) in enumerate(real_val_loader):
            count += 1
            pred = model(x)
            val_acc = get_accuracy(pred, y)
            total_val_acc += val_acc
            
        avg_val_acc = total_val_acc / count
        
    print ("Epoch: {}, Training Loss: {}, Avg. Validation Accuracy: {}".format(epoch, epoch_loss, avg_val_acc))

Epoch: 0, Training Loss: 1.6636805642734875, Avg. Validation Accuracy: 0.3922643195694087
Epoch: 1, Training Loss: 1.5064737038178877, Avg. Validation Accuracy: 0.46785194591318063
Epoch: 2, Training Loss: 1.4061995527961038, Avg. Validation Accuracy: 0.5288675891709512
Epoch: 3, Training Loss: 1.2786825624379246, Avg. Validation Accuracy: 0.5662485485072447
Epoch: 4, Training Loss: 1.1857111074707725, Avg. Validation Accuracy: 0.5928820712345174
Epoch: 5, Training Loss: 1.117525420405648, Avg. Validation Accuracy: 0.6036269974000935
Epoch: 6, Training Loss: 1.0351213135502555, Avg. Validation Accuracy: 0.6339214313361766
Epoch: 7, Training Loss: 0.9688996320421045, Avg. Validation Accuracy: 0.6899125361503856
Epoch: 8, Training Loss: 0.8632748425006866, Avg. Validation Accuracy: 0.7318433318970553
Epoch: 9, Training Loss: 0.7721823399717157, Avg. Validation Accuracy: 0.7645902408565085
Epoch: 10, Training Loss: 0.6602438634092157, Avg. Validation Accuracy: 0.8196209869420426
Epoch: 11

In [18]:
def save_model(ckpt_dir, cp_name, model):
    """
    Create directory /Checkpoint under exp_data_path and save encoder as cp_name
    """
    os.makedirs(ckpt_dir, exist_ok=True)
    saving_model_path = os.path.join(ckpt_dir, cp_name)
    if isinstance(model, torch.nn.DataParallel):
        model = model.module  # convert to non-parallel form
    torch.save(model.state_dict(), saving_model_path)
    print(f'Model saved: {saving_model_path}')

In [19]:
exp_dir = f'v5_trained_real_imgs_classification/'
save_model(exp_dir, f'v5_model.pt', model)

Model saved: v5_trained_real_imgs_classification/v5_model.pt


In [20]:
!cd v5_trained_real_imgs_classification && ls

v5_model.pt
