In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary
from utils import *
from model_training import DoodleDataset, RealDataset

In [18]:
class ConvNeXtBlock2(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv1 = nn.Conv2d(dim, dim, (7, 7), padding=3, groups=dim)
        self.lin1 = nn.Linear(dim, 4 * dim)
        self.lin2 = nn.Linear(4 * dim, dim)
        self.ln = nn.LayerNorm(dim)
        self.gelu = nn.GELU()
    def forward(self, x):
        res_inp = x
        x = self.conv1(x)
        x = x.permute(0, 2, 3, 1)  # NCHW -> NHWC
        x = self.ln(x)
        x = self.lin1(x)
        x = self.gelu(x)
        x = self.lin2(x)
        x = self.gelu(x)
        x = x.permute(0, 3, 1, 2)  # NHWC -> NCHW
        out = x + res_inp
        return out

class ConvNeXt(nn.Module):
    def __init__(self, n_channels, n_classes=9, dropout=0.2, block_dims=[192, 384, 768]):
        super().__init__()
        self.blocks = nn.Sequential(
            nn.Conv2d(n_channels, block_dims[0], kernel_size=2, stride=2),
            ConvNeXtBlock2(block_dims[0]),
            nn.Conv2d(block_dims[0], block_dims[1], kernel_size=2, stride=2),
            ConvNeXtBlock2(block_dims[1]),
            nn.Conv2d(block_dims[1], block_dims[2], kernel_size=2, stride=2),
            ConvNeXtBlock2(block_dims[2]),
            nn.AdaptiveAvgPool2d((1,1))
        )
        self.flatten = nn.Flatten(1)
        self.block_dims = block_dims
        self.project = nn.Linear(block_dims[-1], n_classes)

    def forward(self, x, return_feats=False):
        x = self.blocks(x)
        feats = self.flatten(x)
        print ("feats size: ", feats.shape)
        out = self.project(feats)
        if return_feats:
            return out, feats
        return out

In [19]:
net = ConvNeXt(3, 9)
x = torch.rand(100, 3, 64, 64)
y, feats = net(x, return_feats=True)
print (y.shape, feats.shape)

feats size:  torch.Size([100, 768])
torch.Size([100, 9]) torch.Size([100, 768])


In [5]:
print (summary(net, input_size=(512, 3, 64, 64)))

Layer (type:depth-idx)                   Output Shape              Param #
ConvNeXt                                 --                        --
├─Sequential: 1-1                        [512, 768, 8, 8]          --
│    └─Conv2d: 2-1                       [512, 192, 32, 32]        2,496
│    └─ConvNeXtBlock: 2-2                [512, 192, 32, 32]        --
│    │    └─Conv2d: 3-1                  [512, 192, 32, 32]        9,600
│    │    └─LayerNorm: 3-2               [512, 32, 32, 192]        384
│    │    └─Linear: 3-3                  [512, 32, 32, 768]        148,224
│    │    └─Linear: 3-4                  [512, 32, 32, 192]        147,648
│    │    └─GELU: 3-5                    [512, 32, 32, 192]        --
│    └─Conv2d: 2-3                       [512, 384, 16, 16]        295,296
│    └─ConvNeXtBlock: 2-4                [512, 384, 16, 16]        --
│    │    └─Conv2d: 3-6                  [512, 384, 16, 16]        19,200
│    │    └─LayerNorm: 3-7               [512, 16, 16, 384]

In [6]:
real_train_set = RealDataset(train=True)
real_val_set = RealDataset(train=False)

real_train_loader = torch.utils.data.DataLoader(real_train_set, batch_size=512, shuffle=True)
real_val_loader = torch.utils.data.DataLoader(real_val_set, batch_size=512)

Loaded dataset at 'dataset/cifar/cifar.npy'.
Loaded dataset at 'dataset/sketchy/sketchy_real.npy'.
Loaded dataset at 'dataset/google_images/google_real.npy'.
Loaded dataset at 'dataset/cifar/cifar.npy'.
Loaded dataset at 'dataset/sketchy/sketchy_real.npy'.
Loaded dataset at 'dataset/google_images/google_real.npy'.


In [7]:
print (real_train_set.X.shape)

(11141, 64, 64, 3)


In [8]:
criterion = nn.CrossEntropyLoss()
model = ConvNeXt(3, 9)
optim = torch.optim.AdamW(model.parameters())

In [9]:
fix_seed(0)  # zero seed by default
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

model = nn.DataParallel(model).cuda()

In [10]:
def get_accuracy(pred, label):
    pred, label = pred.cpu(), label.cpu()
    return (pred.argmax(1) == label).sum().item() / len(label)

In [11]:
epochs= 20

for epoch in range(epochs):
    total_loss = 0
    model.train()
    count = 0
    for idx, (x, y) in enumerate(real_train_loader):
        count += 1
        optim.zero_grad()
        
        x, y = x.cuda(), y.cuda()
        pred = model(x)
        loss = criterion(pred, y)

        total_loss += loss.detach().cpu().item()
        loss.backward()
        optim.step()
    
    epoch_loss = total_loss / count
    
    total_val_acc = 0
    with torch.no_grad():
        model.eval()
        count = 0
        for idx, (x, y) in enumerate(real_val_loader):
            count += 1
            pred = model(x)
            val_acc = get_accuracy(pred, y)
            total_val_acc += val_acc
            
        avg_val_acc = total_val_acc / count
        
    print ("Epoch: {}, Training Loss: {}, Avg. Validation Accuracy: {}".format(epoch, epoch_loss, avg_val_acc))

Epoch: 0, Training Loss: 2.1568497852845625, Avg. Validation Accuracy: 0.3599145171184856
Epoch: 1, Training Loss: 1.6310477690263228, Avg. Validation Accuracy: 0.43760475395828463
Epoch: 2, Training Loss: 1.484411207112399, Avg. Validation Accuracy: 0.5142387787742463
Epoch: 3, Training Loss: 1.3478566733273594, Avg. Validation Accuracy: 0.5518443086877775
Epoch: 4, Training Loss: 1.2495673190463672, Avg. Validation Accuracy: 0.5675514686550596
Epoch: 5, Training Loss: 1.1609799699349836, Avg. Validation Accuracy: 0.5801586874123627
Epoch: 6, Training Loss: 1.0637939978729596, Avg. Validation Accuracy: 0.6205451039962608
Epoch: 7, Training Loss: 1.0163520926778966, Avg. Validation Accuracy: 0.6147414151963075
Epoch: 8, Training Loss: 0.937765511599454, Avg. Validation Accuracy: 0.6923878333868895
Epoch: 9, Training Loss: 0.8219836245883595, Avg. Validation Accuracy: 0.752622728733349
Epoch: 10, Training Loss: 0.7162175693295219, Avg. Validation Accuracy: 0.7968396255696425
Epoch: 11, 

In [12]:
def save_model(ckpt_dir, cp_name, model):
    """
    Create directory /Checkpoint under exp_data_path and save encoder as cp_name
    """
    os.makedirs(ckpt_dir, exist_ok=True)
    saving_model_path = os.path.join(ckpt_dir, cp_name)
    if isinstance(model, torch.nn.DataParallel):
        model = model.module # convert to non-parallel form
    torch.save(model.state_dict(), saving_model_path)
    print(f'Model saved: {saving_model_path}')

In [13]:
exp_dir = f'v5_trained_real_imgs_classification/'
save_model(exp_dir, f'v5_model.pt', model)

Model saved: v5_trained_real_imgs_classification/v5_model.pt


In [14]:
!cd v5_trained_real_imgs_classification && ls

v5_model.pt


In [22]:
doodle_train_set = DoodleDataset(train=True)
doodle_val_set = DoodleDataset(train=False)

doodle_train_loader = torch.utils.data.DataLoader(doodle_train_set, batch_size=128, shuffle=True)
doodle_val_loader = torch.utils.data.DataLoader(doodle_val_set, batch_size=128)

Loaded dataset at 'dataset/sketchy/sketchy_doodle.npy'.
Loaded dataset at 'dataset/tuberlin/tuberlin.npy'.
Loaded dataset at 'dataset/google_images/google_doodles.npy'.
Loaded dataset at 'dataset/sketchy/sketchy_doodle.npy'.
Loaded dataset at 'dataset/tuberlin/tuberlin.npy'.
Loaded dataset at 'dataset/google_images/google_doodles.npy'.


In [23]:
criterion = nn.CrossEntropyLoss()
model = ConvNeXt(1, 9)
optim = torch.optim.AdamW(model.parameters())

In [24]:
fix_seed(0)  # zero seed by default
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

model = nn.DataParallel(model).cuda()

In [None]:
epochs= 20

for epoch in range(epochs):
    total_loss = 0
    model.train()
    count = 0
    for idx, (x, y) in enumerate(doodle_train_loader):
        count += 1
        optim.zero_grad()
        
        x, y = x.cuda(), y.cuda()
        pred = model(x)
        loss = criterion(pred, y)

        total_loss += loss.detach().cpu().item()
        loss.backward()
        optim.step()
    
    epoch_loss = total_loss / count
    
    total_val_acc = 0
    with torch.no_grad():
        model.eval()
        count = 0
        for idx, (x, y) in enumerate(doodle_val_loader):
            count += 1
            pred = model(x)
            val_acc = get_accuracy(pred, y)
            total_val_acc += val_acc
            
        avg_val_acc = total_val_acc / count
        
    print ("Epoch: {}, Training Loss: {}, Avg. Validation Accuracy: {}".format(epoch, epoch_loss, avg_val_acc))

Epoch: 0, Training Loss: 3.5000380788530623, Avg. Validation Accuracy: 0.12388392857142858
Epoch: 1, Training Loss: 2.2491656371525357, Avg. Validation Accuracy: 0.12834821428571427
Epoch: 2, Training Loss: 2.2187007665634155, Avg. Validation Accuracy: 0.1640625
Epoch: 3, Training Loss: 2.162825516292027, Avg. Validation Accuracy: 0.16350446428571427
Epoch: 4, Training Loss: 2.1408002887453352, Avg. Validation Accuracy: 0.20051316273932254
Epoch: 5, Training Loss: 2.1166522332600186, Avg. Validation Accuracy: 0.15513392857142858
Epoch: 6, Training Loss: 2.0968960353306363, Avg. Validation Accuracy: 0.22544642857142858
Epoch: 7, Training Loss: 2.0451242923736572, Avg. Validation Accuracy: 0.19363839285714285
