In [66]:
import torch.nn as nn
import torch.utils.data.dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from losses import compute_contrastive_loss_from_feats
from utils import * 
from models import *

from dataset import ImageDataset
from training_config import doodles, reals, doodle_size, real_size, NUM_CLASSES

fix_seed(0) # zero seed by default

In [35]:
ckpt_dir = 'exp_data'

In [54]:
class ConvBlock(nn.Module):
    def __init__(self, inchannels, outchannels, kernel, stride, padding=0, bias=True):
        super().__init__()        
        self.block = nn.Sequential(
                        nn.Conv2d(
                            inchannels, 
                            outchannels, 
                            kernel_size=kernel, 
                            stride=stride, 
                            padding=padding, 
                            bias=bias
                        ),
                        nn.BatchNorm2d(outchannels),
                        nn.ReLU(inplace=True)
                    )
        
    def forward(self, x):
        return self.block(x)

In [63]:
class ConvNet(nn.Module):
    CHANNELS = [64, 128, 192]
    POOL = (1, 1)

    def __init__(self, in_c, num_classes, dropout=0.2):
        super().__init__()
        layer1 = ConvBlock(in_c, self.CHANNELS[0], kernel=3, stride=2, padding=1, bias=True)
        layer2 = ConvBlock(self.CHANNELS[0], self.CHANNELS[1], kernel=3, stride=2, padding=1, bias=True)
        layer3 = ConvBlock(self.CHANNELS[1], self.CHANNELS[2], kernel=3, stride=2, padding=1, bias=True)
        pool = nn.AdaptiveAvgPool2d(self.POOL)
        self.layers = nn.Sequential(layer1, layer2, layer3, pool)

        self.dropout = nn.Dropout(p=dropout)
        self.nn = nn.Sequential(
                    nn.Linear(self.POOL[0] * self.POOL[1] * self.CHANNELS[2], 64),
                    nn.Linear(64, num_classes)
                )

    def forward(self, x, return_feats=False):
        feats = self.layers(x)
        feats = feats.flatten(1)
        x = self.nn(self.dropout(feats))

        if return_feats:
            return x, feats

        return x

In [64]:
x = torch.rand(100, 3, 64, 64)
net = ConvNet(3, 9)
y = net(x)
print (y.shape)

torch.Size([100, 9])


In [68]:
def train_model(model1, model2, train_set, val_set, tqdm_on, id, num_epochs, batch_size, learning_rate, c1, c2, t):
    # training side
    optimizer = torch.optim.AdamW(params=list(model1.parameters()) + list(model2.parameters()),
                                  lr=learning_rate, weight_decay=3e-4)
    criterion = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    # load the training data
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                              num_workers=16, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=16,
                            pin_memory=True, drop_last=True)

    # training loop
    for epoch in range(num_epochs):
        loss1_model1 = AverageMeter()
        loss1_model2 = AverageMeter()
        loss2_model1 = AverageMeter()
        loss2_model2 = AverageMeter()
        loss3_combined = AverageMeter()
        acc_model1 = AverageMeter()
        acc_model2 = AverageMeter()

        model1.train()
        model2.train()
        pg = tqdm(train_loader, leave=False, total=len(train_loader), disable=not tqdm_on)
        for i, (x1, y1, x2, y2) in enumerate(pg):
            # train model1 (doodle)
            pred1, feats1 = model1(x1, return_feats=True)
            loss_1 = criterion(pred1, y1)    # classification loss
            loss_2 = compute_contrastive_loss_from_feats(feats1, y1, t)
            loss1_model1.update(loss_1)
            loss2_model1.update(loss_2)
            loss_model1 = loss_1 + c1 * loss_2

            # train model2 (real)
            pred2, feats2 = model2(x2, return_feats=True)
            loss_1 = criterion(pred2, y2)   # classification loss
            loss_2 = compute_contrastive_loss_from_feats(feats2, y2, t)
            loss1_model2.update(loss_1)
            loss2_model2.update(loss_2)
            loss_model2 = loss_1 + c1 * loss_2

            # the third loss
            combined_feat = feats1 * feats2
            loss_3 = compute_contrastive_loss_from_feats(combined_feat, y1, t)
            loss3_combined.update(loss_3)

            loss = loss_model1 + loss_model2 + c2 * loss_3

            # statistics
            acc_model1.update(compute_accuracy(pred1, y1))
            acc_model2.update(compute_accuracy(pred2, y2))

            # optimization
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # display
            pg.set_postfix({
                'acc 1': '{:.6f}'.format(acc_model1.avg),
                'acc 2': '{:.6f}'.format(acc_model2.avg),
                'l1m1': '{:.6f}'.format(loss1_model1.avg),
                'l2m1': '{:.6f}'.format(loss2_model1.avg),
                'l1m2': '{:.6f}'.format(loss1_model2.avg),
                'l2m2': '{:.6f}'.format(loss2_model2.avg),
                'train epoch': '{:03d}'.format(epoch)
            })

        print(f'train epoch {epoch}, acc 1={acc_model1.avg:.3f}, acc 2={acc_model2.avg:.3f}, l1m1={loss1_model1.avg:.3f},'
              f'l1m2={loss1_model2.avg:.3f}, l2m1={loss2_model1.avg:.3f}, l2m2={loss2_model2.avg:.3f}, '
              f'l3={loss3_combined.avg:.3f}')

        # validation
        model1.eval(), model1.eval()
        acc_model1.reset(), acc_model2.reset()
        pg = tqdm(val_loader, leave=False, total=len(val_loader), disable=not tqdm_on)
        
        with torch.no_grad():
            for i, (x1, y1, x2, y2) in enumerate(pg):
                pred1, feats1 = model1(x1, return_feats=True)
                pred2, feats2 = model2(x2, return_feats=True)
                acc_model1.update(compute_accuracy(pred1, y1))
                acc_model2.update(compute_accuracy(pred2, y2))

                # display
                pg.set_postfix({
                    'acc 1': '{:.6f}'.format(acc_model1.avg),
                    'acc 2': '{:.6f}'.format(acc_model2.avg),
                    'val epoch': '{:03d}'.format(epoch)
                })

        print(f'validation epoch {epoch}, acc 1 (doodle) = {acc_model1.avg:.3f}, acc 2 (real) = {acc_model2.avg:.3f}')

        scheduler.step()

    print(f'training finished')

    # save checkpoint
    exp_dir = f'exp_data/{id}'
    save_model(exp_dir, f'{id}_model1.pt', model1)
    save_model(exp_dir, f'{id}_model2.pt', model2)

In [None]:
train_set = ImageDataset(doodles, reals, doodle_size, real_size, train=True)
val_set = ImageDataset(doodles, reals, doodle_size, real_size, train=False)

# tunable hyper params.
num_epochs, base_bs, base_lr = 10, 512, 2e-2
c1, c2, t = 1, 1, 0.1  # contrastive learning. if you want vanilla (cross-entropy) training, set c1 and c2 to 0.
dropout = 0.2

# models
doodle_model = ConvNet(1, NUM_CLASSES, dropout)
real_model = ConvNet(3, NUM_CLASSES, dropout)

# just some logistics
tqdm_on = True     # progress bar
id = 25              # change to the id of each experiment accordingly

train_model(doodle_model, real_model, train_set, val_set, tqdm_on, id, num_epochs, base_bs, base_lr, c1, c2, t)

Train = True. Doodle list: ['sketchy_doodle', 'tuberlin', 'google_doodles'], 
 real list: ['sketchy_real', 'google_real', 'cifar']. 
 classes: dict_keys(['airplane', 'car', 'cat', 'dog', 'frog', 'horse', 'truck', 'bird', 'ship']) 
Doodle data size 7022, real data size 46364, ratio 0.15145371408851696
Train = False. Doodle list: ['sketchy_doodle', 'tuberlin', 'google_doodles'], 
 real list: ['sketchy_real', 'google_real', 'cifar']. 
 classes: dict_keys(['airplane', 'car', 'cat', 'dog', 'frog', 'horse', 'truck', 'bird', 'ship']) 
Doodle data size 1764, real data size 9341, ratio 0.18884487742211756


                                                                                                                                                            

train epoch 0, acc 1=0.194, acc 2=0.329, l1m1=2.066,l1m2=1.721, l2m1=2.206, l2m2=2.043, l3=2.023


                                                                                              

validation epoch 0, acc 1 (doodle) = 0.185, acc 2 (real) = 0.323


                                                                                                                                                            

train epoch 1, acc 1=0.278, acc 2=0.453, l1m1=1.839,l1m2=1.431, l2m1=2.068, l2m2=1.836, l3=1.690


                                                                                              

validation epoch 1, acc 1 (doodle) = 0.273, acc 2 (real) = 0.441


                                                                                                                                                            

train epoch 2, acc 1=0.351, acc 2=0.525, l1m1=1.713,l1m2=1.269, l2m1=2.018, l2m2=1.739, l3=1.531


                                                                                              

validation epoch 2, acc 1 (doodle) = 0.252, acc 2 (real) = 0.476


                                                                                                                                                            

train epoch 3, acc 1=0.399, acc 2=0.569, l1m1=1.604,l1m2=1.166, l2m1=1.971, l2m2=1.669, l3=1.426


                                                                                              

validation epoch 3, acc 1 (doodle) = 0.393, acc 2 (real) = 0.506


                                                                                                                                                            

train epoch 4, acc 1=0.449, acc 2=0.592, l1m1=1.493,l1m2=1.105, l2m1=1.922, l2m2=1.628, l3=1.340


                                                                                              

validation epoch 4, acc 1 (doodle) = 0.284, acc 2 (real) = 0.515


                                                                                                                                                            

train epoch 5, acc 1=0.495, acc 2=0.626, l1m1=1.377,l1m2=1.026, l2m1=1.858, l2m2=1.565, l3=1.237


                                                                                              

validation epoch 5, acc 1 (doodle) = 0.421, acc 2 (real) = 0.543


 99%|█████████▉| 89/90 [12:45<00:08,  8.27s/it, acc 1=0.528946, acc 2=0.647231, l1m1=1.288378, l2m1=1.809368, l1m2=0.983193, l2m2=1.537351, train epoch=006]