In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from random import randint
import time
import utils
import numpy as np
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import cv2
import random

In [2]:
!pip install tqdm

Collecting tqdm
  Downloading tqdm-4.64.0-py2.py3-none-any.whl (78 kB)
[K     |████████████████████████████████| 78 kB 9.7 MB/s  eta 0:00:01
[?25hInstalling collected packages: tqdm
Successfully installed tqdm-4.64.0


In [4]:
import torch.utils.data.dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from losses import compute_contrastive_loss_from_feats
from utils import *  # bad practice, nvm
from models import *

from dataset import ImageDataset
from training_config import doodles, reals, doodle_size, real_size, NUM_CLASSES

In [5]:
ckpt_dir = 'exp_data'

In [15]:
def convbn(in_channels, out_channels, kernel_size, stride, padding, bias):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True)
    )

class V2ConvNet(nn.Module):
    CHANNELS = [64, 128, 192, 256, 512]
    POOL = (1, 1)

    def __init__(self, in_c, num_classes, dropout=0.2, add_layers=False):
        super().__init__()
        layer1 = convbn(in_c, self.CHANNELS[1], kernel_size=3, stride=2, padding=1, bias=True)
        layer2 = convbn(self.CHANNELS[1], self.CHANNELS[2], kernel_size=3, stride=2, padding=1, bias=True)
        layer3 = convbn(self.CHANNELS[2], self.CHANNELS[3], kernel_size=3, stride=2, padding=1, bias=True)
        layer4 = convbn(self.CHANNELS[3], self.CHANNELS[4], kernel_size=3, stride=2, padding=1, bias=True)
        pool = nn.AdaptiveAvgPool2d(self.POOL)
        self.layers = nn.Sequential(layer1, layer2, layer3, layer4, pool)

        if add_layers:
            layer1_2 = convbn(self.CHANNELS[1], self.CHANNELS[1], kernel_size=3, stride=1, padding=0, bias=True)
            layer2_2 = convbn(self.CHANNELS[2], self.CHANNELS[2], kernel_size=3, stride=1, padding=0, bias=True)
            layer3_2 = convbn(self.CHANNELS[3], self.CHANNELS[3], kernel_size=3, stride=1, padding=0, bias=True)
            layer4_2 = convbn(self.CHANNELS[4], self.CHANNELS[4], kernel_size=3, stride=1, padding=0, bias=True)
            self.layers = nn.Sequential(layer1, layer1_2, layer2, layer2_2, layer3, layer3_2, layer4, layer4_2, pool)

        self.nn = nn.Linear(self.POOL[0] * self.POOL[1] * self.CHANNELS[4], num_classes)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, return_feats=False):
        feats = self.layers(x).flatten(1)
        x = self.nn(self.dropout(feats))

        if return_feats:
            return x, feats

        return x

In [16]:
x = torch.rand(100, 3, 64, 64)
net = V2ConvNet(3, 9, add_layers=True)
y = net(x)
print (y.shape)

torch.Size([100, 9])


In [8]:
def train_model(model1, model2, train_set, val_set, tqdm_on, id, num_epochs, batch_size, learning_rate, c1, c2, t):
    # cuda side setup
    model1 = nn.DataParallel(model1).cuda()
    model2 = nn.DataParallel(model2).cuda()

    # training side
    optimizer = torch.optim.AdamW(params=list(model1.parameters()) + list(model2.parameters()),
                                  lr=learning_rate, weight_decay=3e-4)
    criterion = nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    # load the training data
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True,
                              num_workers=16, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=16,
                            pin_memory=True, drop_last=True)

    # training loop
    for epoch in range(num_epochs):
        loss1_model1 = AverageMeter()
        loss1_model2 = AverageMeter()
        loss2_model1 = AverageMeter()
        loss2_model2 = AverageMeter()
        loss3_combined = AverageMeter()
        acc_model1 = AverageMeter()
        acc_model2 = AverageMeter()

        model1.train()
        model2.train()
        pg = tqdm(train_loader, leave=False, total=len(train_loader), disable=not tqdm_on)
        for i, (x1, y1, x2, y2) in enumerate(pg):
            # doodle, label, real, label
            x1, y1, x2, y2 = x1.cuda(), y1.cuda(), x2.cuda(), y2.cuda()

            # train model1 (doodle)
            pred1, feats1 = model1(x1, return_feats=True)
            loss_1 = criterion(pred1, y1)    # classification loss
            loss_2 = compute_contrastive_loss_from_feats(feats1, y1, t)
            loss1_model1.update(loss_1)
            loss2_model1.update(loss_2)
            loss_model1 = loss_1 + c1 * loss_2

            # train model2 (real)
            pred2, feats2 = model2(x2, return_feats=True)
            loss_1 = criterion(pred2, y2)   # classification loss
            loss_2 = compute_contrastive_loss_from_feats(feats2, y2, t)
            loss1_model2.update(loss_1)
            loss2_model2.update(loss_2)
            loss_model2 = loss_1 + c1 * loss_2

            # the third loss
            combined_feat = feats1 * feats2
            loss_3 = compute_contrastive_loss_from_feats(combined_feat, y1, t)
            loss3_combined.update(loss_3)

            loss = loss_model1 + loss_model2 + c2 * loss_3

            # statistics
            acc_model1.update(compute_accuracy(pred1, y1))
            acc_model2.update(compute_accuracy(pred2, y2))

            # optimization
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # display
            pg.set_postfix({
                'acc 1': '{:.6f}'.format(acc_model1.avg),
                'acc 2': '{:.6f}'.format(acc_model2.avg),
                'l1m1': '{:.6f}'.format(loss1_model1.avg),
                'l2m1': '{:.6f}'.format(loss2_model1.avg),
                'l1m2': '{:.6f}'.format(loss1_model2.avg),
                'l2m2': '{:.6f}'.format(loss2_model2.avg),
                'train epoch': '{:03d}'.format(epoch)
            })

        print(f'train epoch {epoch}, acc 1={acc_model1.avg:.3f}, acc 2={acc_model2.avg:.3f}, l1m1={loss1_model1.avg:.3f},'
              f'l1m2={loss1_model2.avg:.3f}, l2m1={loss2_model1.avg:.3f}, l2m2={loss2_model2.avg:.3f}, '
              f'l3={loss3_combined.avg:.3f}')

        # validation
        model1.eval(), model1.eval()
        acc_model1.reset(), acc_model2.reset()
        pg = tqdm(val_loader, leave=False, total=len(val_loader), disable=not tqdm_on)
        with torch.no_grad():
            for i, (x1, y1, x2, y2) in enumerate(pg):
                pred1, feats1 = model1(x1, return_feats=True)
                pred2, feats2 = model2(x2, return_feats=True)
                acc_model1.update(compute_accuracy(pred1, y1))
                acc_model2.update(compute_accuracy(pred2, y2))

                # display
                pg.set_postfix({
                    'acc 1': '{:.6f}'.format(acc_model1.avg),
                    'acc 2': '{:.6f}'.format(acc_model2.avg),
                    'val epoch': '{:03d}'.format(epoch)
                })

        print(f'validation epoch {epoch}, acc 1 (doodle) = {acc_model1.avg:.3f}, acc 2 (real) = {acc_model2.avg:.3f}')

        scheduler.step()

    print(f'training finished')

    # save checkpoint
    exp_dir = f'exp_data/{id}'
    save_model(exp_dir, f'{id}_model1.pt', model1)
    save_model(exp_dir, f'{id}_model2.pt', model2)

In [9]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

In [19]:
fix_seed(0)

train_set = ImageDataset(doodles, reals, doodle_size, real_size, train=True)
val_set = ImageDataset(doodles, reals, doodle_size, real_size, train=False)

# tunable hyper params.
use_cnn = True
num_epochs, base_bs, base_lr = 15, 512, 2e-2
c1, c2, t = 0, 0, 0.1  # contrastive learning. if you want vanilla (cross-entropy) training, set c1 and c2 to 0.
dropout = 0.3

# models
doodle_model = V2ConvNet(1, NUM_CLASSES, dropout, add_layers=True)
real_model = V2ConvNet(3, NUM_CLASSES, dropout, add_layers=True)

# just some logistics
tqdm_on = True     # progress bar
id = 24             # change to the id of each experiment accordingly

train_model(doodle_model, real_model, train_set, val_set, tqdm_on, id, num_epochs, base_bs, base_lr, c1, c2, t)

Train = True. Doodle list: ['sketchy_doodle', 'tuberlin', 'google_doodles'], 
 real list: ['sketchy_real', 'google_real', 'cifar']. 
 classes: dict_keys(['airplane', 'car', 'cat', 'dog', 'frog', 'horse', 'truck', 'bird', 'ship']) 
Doodle data size 7022, real data size 46364, ratio 0.15145371408851696
Train = False. Doodle list: ['sketchy_doodle', 'tuberlin', 'google_doodles'], 
 real list: ['sketchy_real', 'google_real', 'cifar']. 
 classes: dict_keys(['airplane', 'car', 'cat', 'dog', 'frog', 'horse', 'truck', 'bird', 'ship']) 
Doodle data size 1764, real data size 9341, ratio 0.18884487742211756


                                                                                                          

train epoch 0, acc 1=0.286, acc 2=0.337, l1m1=1.987,l1m2=1.845, l2m1=4.619, l2m2=3.753, l3=4.773


                                                                                                          

validation epoch 0, acc 1 (doodle) = 0.408, acc 2 (real) = 0.387


                                                                                                          

train epoch 1, acc 1=0.541, acc 2=0.520, l1m1=1.249,l1m2=1.309, l2m1=3.740, l2m2=3.072, l3=4.500


                                                                                                          

validation epoch 1, acc 1 (doodle) = 0.480, acc 2 (real) = 0.498


                                                                                                          

train epoch 2, acc 1=0.766, acc 2=0.634, l1m1=0.655,l1m2=1.013, l2m1=3.268, l2m2=2.660, l3=4.521


                                                                                                          

validation epoch 2, acc 1 (doodle) = 0.587, acc 2 (real) = 0.564


                                                                                                          

train epoch 3, acc 1=0.930, acc 2=0.712, l1m1=0.211,l1m2=0.806, l2m1=2.950, l2m2=2.329, l3=4.381


                                                                                                          

validation epoch 3, acc 1 (doodle) = 0.613, acc 2 (real) = 0.615


                                                                                                          

train epoch 4, acc 1=0.980, acc 2=0.769, l1m1=0.065,l1m2=0.650, l2m1=2.754, l2m2=2.169, l3=4.306


                                                                                                          

validation epoch 4, acc 1 (doodle) = 0.603, acc 2 (real) = 0.649


                                                                                                          

train epoch 5, acc 1=0.991, acc 2=0.815, l1m1=0.031,l1m2=0.530, l2m1=2.667, l2m2=2.115, l3=4.265


                                                                                                          

validation epoch 5, acc 1 (doodle) = 0.610, acc 2 (real) = 0.678


                                                                                                          

train epoch 6, acc 1=0.996, acc 2=0.857, l1m1=0.013,l1m2=0.411, l2m1=2.613, l2m2=2.014, l3=4.205


                                                                                                          

validation epoch 6, acc 1 (doodle) = 0.640, acc 2 (real) = 0.679


                                                                                                          

train epoch 7, acc 1=0.999, acc 2=0.888, l1m1=0.003,l1m2=0.323, l2m1=2.526, l2m2=1.950, l3=4.130


                                                                                                          

validation epoch 7, acc 1 (doodle) = 0.635, acc 2 (real) = 0.698


                                                                                                          

train epoch 8, acc 1=1.000, acc 2=0.923, l1m1=0.002,l1m2=0.229, l2m1=2.486, l2m2=1.887, l3=4.104


                                                                                                          

validation epoch 8, acc 1 (doodle) = 0.648, acc 2 (real) = 0.699


                                                                                                          

train epoch 9, acc 1=1.000, acc 2=0.949, l1m1=0.001,l1m2=0.155, l2m1=2.443, l2m2=1.778, l3=4.053


                                                                                                          

validation epoch 9, acc 1 (doodle) = 0.648, acc 2 (real) = 0.703


                                                                                                          

train epoch 10, acc 1=1.000, acc 2=0.967, l1m1=0.001,l1m2=0.099, l2m1=2.438, l2m2=1.721, l3=4.031


                                                                                                          

validation epoch 10, acc 1 (doodle) = 0.652, acc 2 (real) = 0.703


                                                                                                          

train epoch 11, acc 1=1.000, acc 2=0.982, l1m1=0.001,l1m2=0.060, l2m1=2.430, l2m2=1.655, l3=3.990


                                                                                                          

validation epoch 11, acc 1 (doodle) = 0.652, acc 2 (real) = 0.711


                                                                                                          

train epoch 12, acc 1=1.000, acc 2=0.992, l1m1=0.001,l1m2=0.032, l2m1=2.410, l2m2=1.597, l3=3.954


                                                                                                          

validation epoch 12, acc 1 (doodle) = 0.653, acc 2 (real) = 0.720


                                                                                                          

train epoch 13, acc 1=1.000, acc 2=0.995, l1m1=0.000,l1m2=0.025, l2m1=2.415, l2m2=1.558, l3=3.942


                                                                                                          

validation epoch 13, acc 1 (doodle) = 0.647, acc 2 (real) = 0.719


                                                                                                          

train epoch 14, acc 1=1.000, acc 2=0.997, l1m1=0.001,l1m2=0.019, l2m1=2.413, l2m2=1.553, l3=3.938


                                                                                                          

validation epoch 14, acc 1 (doodle) = 0.644, acc 2 (real) = 0.716
training finished
Model saved: exp_data/24/24_model1.pt
Model saved: exp_data/24/24_model2.pt


