In [None]:
import argparse
import copy
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST, SVHN

In [None]:
# Hyper-parameters and run configuration shared by the cells below.
batch_size = 64           # mini-batch size of the training loaders
test_batch_size = 1000    # mini-batch size intended for evaluation
epochs = 100
lr = 0.03
momentum = 0.5            # NOTE(review): the optimizers in the driver cell hard-code momentum=0.9
seed = 40
# Use the GPU only when one is actually present so the notebook also runs on CPU.
cuda = torch.cuda.is_available()
log_interval = 100        # print a training log line every `log_interval` global steps
random = False            # True selects the randomized-layer variant in the driver cell

# Apply the seed; declaring it alone does not make the run reproducible.
torch.manual_seed(seed)
np.random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
def calc_coeff(iter_num, high=1.0, low=0.0, alpha=10.0, max_iter=10000.0):
    """Return the ramp-up coefficient for step ``iter_num`` as a Python float.

    The value follows a rescaled logistic curve in ``iter_num / max_iter``:
    it equals ``low`` at ``iter_num == 0`` and approaches ``high`` as
    ``iter_num`` grows. ``alpha`` controls how steep the ramp is.
    """
    span = high - low
    logistic_denominator = 1.0 + np.exp(-alpha * iter_num / max_iter)
    return float(2.0 * span / logistic_denominator - span + low)

In [None]:
def init_weights(m):
    """Initialise module ``m`` in place; intended for ``nn.Module.apply``.

    The layer type is matched by substring on the class name:

    * ``Conv2d`` / ``ConvTranspose2d``: Kaiming-uniform weights, zero bias.
    * ``BatchNorm``: weights drawn from N(1.0, 0.02), zero bias.
    * ``Linear``: Xavier-normal weights, zero bias.

    Any other module is left untouched. Layers created with ``bias=False``
    or ``affine=False`` expose ``None`` for the missing parameter; those
    parameters are skipped instead of raising.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv2d' in classname or 'ConvTranspose2d' in classname:
        if weight is not None:
            nn.init.kaiming_uniform_(weight)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight, 1.0, 0.02)
    elif 'Linear' in classname:
        if weight is not None:
            nn.init.xavier_normal_(weight)
    else:
        # Not a layer type this initialiser knows about.
        return

    if bias is not None:
        nn.init.zeros_(bias)

In [None]:
def Entropy(input_):
    """Return ``-sum(p * log(p + 1e-5))`` taken over dim 1 of ``input_``.

    The small constant inside the logarithm keeps zero entries from
    producing ``log(0)``. The result has dim 1 reduced away.
    """
    eps = 1e-5
    per_entry = -input_ * torch.log(input_ + eps)
    return per_entry.sum(dim=1)


def grl_hook(coeff):
    """Build a backward hook that flips and scales a gradient.

    The returned callable maps ``grad`` to ``-coeff * grad`` (computed on a
    clone of ``grad``); it is meant to be passed to ``Tensor.register_hook``.
    """
    def fun1(grad):
        return -coeff * grad.clone()

    return fun1


def CDAN(input_list, ad_net, entropy=None, coeff=None, random_layer=None):
    """Conditional adversarial domain-discrimination loss.

    The batch is assumed to be ordered ``[first half | second half]``; the
    first half gets domain label 1 and the second half domain label 0.

    Args:
        input_list: ``[feature, softmax_output]``. ``softmax_output`` is
            detached before use, so no gradient flows through it.
        ad_net: callable mapping the conditioned features to one probability
            per sample (shape ``(N, 1)``, values in [0, 1] as required by
            ``BCELoss``).
        entropy: optional per-sample entropy tensor. When given, each sample's
            loss is weighted by ``1 + exp(-entropy)``, normalised separately
            within each half. The tensor must require grad because a
            gradient-reversal hook is registered on it.
        coeff: scale used by the gradient-reversal hook on ``entropy``.
        random_layer: optional object whose ``forward([feature, softmax])``
            produces the conditioned features; when ``None`` the flattened
            outer product of ``softmax_output`` and ``feature`` is used.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: if the batch cannot be split into two equal halves.
    """
    feature = input_list[0]
    softmax_output = input_list[1].detach()

    if random_layer is None:
        # Per-sample outer product (N, C, 1) x (N, 1, F) -> (N, C, F), then flatten.
        op_out = torch.bmm(softmax_output.unsqueeze(2), feature.unsqueeze(1))
        ad_out = ad_net(op_out.view(-1, softmax_output.size(1) * feature.size(1)))
    else:
        random_out = random_layer.forward([feature, softmax_output])
        ad_out = ad_net(random_out.view(-1, random_out.size(1)))

    total = softmax_output.size(0)
    if total % 2 != 0:
        raise ValueError(
            'CDAN expects an even batch (two equal halves), got {} samples'.format(total))
    half = total // 2

    # Build the domain labels on the discriminator's device instead of relying
    # on a module-level CUDA flag.
    dc_target = torch.cat((torch.ones(half, 1), torch.zeros(half, 1))).to(ad_out.device)

    if entropy is None:
        return nn.BCELoss()(ad_out, dc_target)

    entropy.register_hook(grl_hook(coeff))
    entropy = 1.0 + torch.exp(-entropy)

    split = feature.size(0) // 2
    source_mask = torch.ones_like(entropy)
    source_mask[split:] = 0
    source_weight = entropy * source_mask
    target_mask = torch.ones_like(entropy)
    target_mask[0:split] = 0
    target_weight = entropy * target_mask

    # Normalise each half so both domains contribute the same total weight.
    weight = source_weight / torch.sum(source_weight).detach().item() + \
             target_weight / torch.sum(target_weight).detach().item()

    # The weights must multiply the *per-sample* losses. Multiplying them by
    # the mean-reduced loss would make the weighting cancel out entirely.
    per_sample_loss = nn.BCELoss(reduction='none')(ad_out, dc_target)
    return torch.sum(weight.view(-1, 1) * per_sample_loss) / torch.sum(weight).detach().item()


def mdd_loss(features, labels, left_weight=1, right_weight=1):
    """Discrepancy loss between the two halves of a batch plus per-half pair losses.

    ``features`` is soft-maxed along dim 1 and split into a first and a second
    half along dim 0. The result is the summed row-wise L2 norm of the
    difference between the halves divided by the full batch size, plus
    ``left_weight`` / ``right_weight`` times the pair loss of each half.

    NOTE(review): ``get_pari_loss1`` is not defined in the visible part of the
    notebook (the name may be a typo); confirm it exists before calling this.

    Raises:
        Exception: if the batch size is odd.
    """
    probs = torch.softmax(features, dim=1)
    n_samples = features.size(0)
    if n_samples % 2:
        raise Exception('Incorrect batch size provided')

    half = n_samples // 2
    first_half, second_half = probs[:half], probs[half:]

    discrepancy = torch.norm((first_half - second_half).abs(), 2, 1).sum() / float(n_samples)

    first_pair_loss = get_pari_loss1(labels[:half], first_half)
    second_pair_loss = get_pari_loss1(labels[half:], second_half)
    return discrepancy + left_weight * first_pair_loss + right_weight * second_pair_loss


def mdd_digit(features, labels, left_weight=1, right_weight=1, weight=1):
    """Weighted half-vs-half discrepancy loss plus a pair loss inside each half.

    ``features`` is soft-maxed along dim 1 and split into two halves along
    dim 0. The discrepancy term is the summed row-wise L2 norm of the
    difference between the halves divided by the full batch size, scaled by
    ``weight``. Each half is then split again at a quarter of the batch size
    and the two parts (with their labels) are handed to ``get_pair_loss``;
    the results are scaled by ``left_weight`` and ``right_weight``.

    NOTE(review): ``get_pair_loss`` is not defined in the visible part of the
    notebook; confirm it is available before calling this.

    Raises:
        Exception: if the batch size is odd.
    """
    probs = torch.softmax(features, dim=1)
    n_samples = features.size(0)
    if n_samples % 2:
        raise Exception('Incorrect batch size provided')

    half = n_samples // 2
    quarter = n_samples // 4

    def pair_loss_for(half_labels, half_probs):
        # Split one half at the quarter mark and compare its two parts.
        return get_pair_loss(half_labels[:quarter], half_labels[quarter:],
                             half_probs[:quarter], half_probs[quarter:])

    first_half, second_half = probs[:half], probs[half:]
    discrepancy = torch.norm((first_half - second_half).abs(), 2, 1).sum() / float(n_samples)

    first_pair_loss = pair_loss_for(labels[:half], first_half)
    second_pair_loss = pair_loss_for(labels[half:], second_half)

    return weight * discrepancy + left_weight * first_pair_loss + right_weight * second_pair_loss


In [None]:
def normalize(data_tensor):
    """Map values from the range [0, 255] linearly onto [-1, 1]."""
    unit_scaled = data_tensor / 255.
    return unit_scaled * 2. - 1.

def tile_image(image):
    """Print the maximum and minimum value of ``image`` and return it unchanged.

    Despite the name no tiling happens; the channel-repeat call is disabled.
    """
    pixels = np.array(image)
    print(pixels.max(), pixels.min())
    return image #image.repeat(3,1,1)

# NOTE(review): `transform_list` is not used by any dataset in this cell.
transform_list = [transforms.ToTensor(), transforms.Lambda(lambda x: normalize(x))]

# MNIST: replicate the grey channel three times and resize to 32 pixels so the
# images match the 3-channel 32x32 input the DTN cell is built for.
mnist_transform = transforms.Compose([
    transforms.Grayscale(3),
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize([0.1307, 0.1307, 0.1307], [0.3081, 0.3081, 0.3081]),
])

# SVHN: only tensor conversion and per-channel normalisation.
svhn_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.4377, 0.4438, 0.4728], [0.1980, 0.2010, 0.1970]),
])

mnist_dataset = MNIST(root="data", train=True, download=True, transform=mnist_transform)
mnist_test_dataset = MNIST(root="data", train=False, download=True, transform=mnist_transform)
svhn_dataset = SVHN(root="data", split='train', download=True, transform=svhn_transform)
svhn_test_dataset = SVHN(root="data", split='test', download=True, transform=svhn_transform)

# Training loaders: shuffled, and the last partial batch is dropped so source
# and target batches always have the same size.
svhn_loader = DataLoader(dataset=svhn_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=0)
mnist_loader = DataLoader(dataset=mnist_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=0)

# Evaluation loaders: keep every sample. Dropping the last partial batch would
# skip test images while the accuracy is still reported against the full
# dataset size. Shuffling is pointless for evaluation.
svhn_test_loader = DataLoader(dataset=svhn_test_dataset, batch_size=test_batch_size, shuffle=False, drop_last=False, num_workers=0)
mnist_test_loader = DataLoader(dataset=mnist_test_dataset, batch_size=test_batch_size, shuffle=False, drop_last=False, num_workers=0)

In [None]:
# Sanity check of the MNIST pipeline: show the first image of one training
# batch and print its tensor shape.
for batch_idx, (data, target) in enumerate(mnist_loader):
    # Undo Normalize(mean=0.1307, std=0.3081) before plotting: imshow expects
    # float RGB values in [0, 1] and clips anything outside that range.
    preview = (data[0] * 0.3081 + 0.1307).clamp(0.0, 1.0)
    plt.imshow(preview.permute(1, 2, 0))
    plt.show()
    print(data[0].size())
    break

In [None]:
class DTN(nn.Module):
    """Convolutional feature extractor with a 10-way linear classifier.

    Three stride-2 convolution stages (each: conv, batch norm, 2-D dropout,
    ReLU) are followed by a fully connected stage that produces a
    512-dimensional feature vector and a linear layer producing 10 logits.
    The first linear layer expects ``256 * 4 * 4`` flattened inputs.
    """

    def __init__(self):
        super().__init__()
        feature_dim = 512

        def conv_stage(in_channels, out_channels, drop_prob):
            # One down-sampling stage; layer order matters for state_dict keys.
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=2, padding=2),
                nn.BatchNorm2d(out_channels),
                nn.Dropout2d(drop_prob),
                nn.ReLU(),
            ]

        self.conv_params = nn.Sequential(
            *conv_stage(3, 64, 0.1),
            *conv_stage(64, 128, 0.3),
            *conv_stage(128, 256, 0.5),
        )

        self.fc_params = nn.Sequential(
            nn.Linear(256 * 4 * 4, feature_dim),
            nn.BatchNorm1d(feature_dim),
            nn.ReLU(),
            nn.Dropout(),
        )

        self.classifier = nn.Linear(feature_dim, 10)
        self.__in_features = feature_dim

    def forward(self, x):
        """Return ``(features, logits)`` for the input batch ``x``."""
        conv_maps = self.conv_params(x)
        flat = conv_maps.view(conv_maps.size(0), -1)
        features = self.fc_params(flat)
        return features, self.classifier(features)

    def output_num(self):
        """Return the size of the feature vector produced by ``forward``."""
        return self.__in_features

In [None]:

class AdversarialNetwork(nn.Module):
    """Three-layer domain discriminator with a built-in gradient-reversal hook.

    ``forward`` returns one sigmoid probability per input row. During the
    backward pass the gradient reaching the input is multiplied by
    ``-coeff``, where ``coeff`` comes from ``calc_coeff`` and ramps with the
    number of training-mode forward passes seen so far.
    """

    def __init__(self, in_feature, hidden_size):
        super(AdversarialNetwork, self).__init__()
        self.ad_layer1 = nn.Linear(in_feature, hidden_size)
        self.ad_layer2 = nn.Linear(hidden_size, hidden_size)
        self.ad_layer3 = nn.Linear(hidden_size, 1)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.dropout1 = nn.Dropout(0.5)
        self.dropout2 = nn.Dropout(0.5)
        self.sigmoid = nn.Sigmoid()
        self.apply(init_weights)
        # State and schedule parameters for the gradient-reversal coefficient.
        self.iter_num = 0
        self.alpha = 10
        self.low = 0.0
        self.high = 1.0
        self.max_iter = 10000.0

    def forward(self, x):
        """Return domain probabilities of shape ``(N, 1)`` for input ``x``."""
        # Only training-mode passes advance the coefficient schedule.
        if self.training:
            self.iter_num += 1
        coeff = calc_coeff(self.iter_num, self.high, self.low, self.alpha, self.max_iter)
        # Multiply by 1.0 to get a fresh tensor, so the hook is attached to
        # this forward pass only and not to the caller's tensor.
        x = x * 1.0
        # Hooks can only be registered on tensors that take part in autograd;
        # skipping it otherwise lets the network run under torch.no_grad().
        if x.requires_grad:
            x.register_hook(grl_hook(coeff))
        x = self.ad_layer1(x)
        x = self.relu1(x)
        x = self.dropout1(x)
        x = self.ad_layer2(x)
        x = self.relu2(x)
        x = self.dropout2(x)
        y = self.ad_layer3(x)
        y = self.sigmoid(y)
        return y

    def output_num(self):
        """Return the output width of the discriminator (always 1)."""
        return 1

    def get_parameters(self):
        """Return one optimizer parameter group with multiplier hints attached."""
        return [{"params": self.parameters(), "lr_mult": 10, 'decay_mult': 2}]

In [None]:
def train(model, ad_net, random_layer, train_loader, train_loader1, optimizer, optimizer_ad, epoch):
    """Run one epoch of adversarial domain-adaptation training.

    Each step draws one batch from ``train_loader`` (labelled source data)
    and one from ``train_loader1`` (target data, labels ignored), and
    minimises the source cross-entropy plus the entropy-weighted ``CDAN``
    loss on the concatenated batch.

    Args:
        model: network returning ``(features, logits)``.
        ad_net: domain discriminator passed on to ``CDAN``.
        random_layer: optional randomized layer passed on to ``CDAN``.
        train_loader: source-domain loader yielding ``(data, label)``.
        train_loader1: target-domain loader yielding ``(data, label)``.
        optimizer: optimizer for ``model``.
        optimizer_ad: optimizer for ``ad_net``.
        epoch: current epoch number; the adversarial term and the
            discriminator update are only active when ``epoch > 0``.
    """
    model.train()
    ad_net.train()
    # Put the batches wherever the model lives; no module-level CUDA flag needed.
    device = next(model.parameters()).device
    classification_criterion = nn.CrossEntropyLoss()

    len_source = len(train_loader)
    len_target = len(train_loader1)
    num_iter = min(len_source, len_target)
    total_loss = 0.0
    for batch_idx in range(num_iter):
        # (Re)start an iterator whenever its loader would be exhausted.
        if batch_idx % len_source == 0:
            iter_source = iter(train_loader)
        if batch_idx % len_target == 0:
            iter_target = iter(train_loader1)
        data_source, label_source = next(iter_source)
        data_target, _ = next(iter_target)  # target labels are never used
        data_source = data_source.to(device)
        label_source = label_source.to(device)
        data_target = data_target.to(device)

        optimizer.zero_grad()
        optimizer_ad.zero_grad()

        feature_source, output_source = model(data_source)
        feature_target, output_target = model(data_target)

        # Source samples first, target samples second: CDAN labels the halves.
        feature = torch.cat((feature_source, feature_target), 0)
        output = torch.cat((output_source, output_target), 0)

        loss = classification_criterion(output_source, label_source)

        if epoch > 0:
            softmax_output = nn.Softmax(dim=1)(output)
            entropy = Entropy(softmax_output)
            loss = loss + CDAN([feature, softmax_output], ad_net, entropy,
                               calc_coeff(num_iter * (epoch - 0) + batch_idx), random_layer)

        # Accumulate a detached tensor so no autograd graph is kept alive and
        # no device synchronisation is forced on every step.
        total_loss = total_loss + loss.detach()

        loss.backward()
        optimizer.step()
        if epoch > 0:
            optimizer_ad.step()
        if (batch_idx + epoch * num_iter) % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * batch_size, num_iter * batch_size,
                       100. * batch_idx / num_iter, loss.item()))
    # Report a plain number, not a tensor repr.
    log_str = "total_loss:{}\n".format(float(total_loss))
    print(log_str)

In [None]:
def test(epoch, model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    for data, target in test_loader:
        if cuda:
            data, target = data.cuda(), target.cuda()
        feature, output = model(data)
        test_loss += nn.CrossEntropyLoss()(output, target).item()
        pred = output.data.cpu().max(1, keepdim=True)[1]
        correct += pred.eq(target.data.cpu().view_as(pred)).sum().item()
    
    test_loss /= len(test_loader.dataset)
    acc = 100. * correct / len(test_loader.dataset)
    log_str = 'epoch:{},Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.4f}%)\n'.format(epoch,
                                                                                            test_loss, correct,
                                                                                            len(test_loader.dataset),
                                                                                            acc)

    print(log_str)
    return acc

In [None]:
# Build the feature extractor / classifier and the domain discriminator.
model = DTN()
if cuda:
    model = model.cuda()
class_num = 10

if random:
    # NOTE(review): `network` is not imported anywhere in the visible notebook,
    # so this branch raises NameError unless it is defined elsewhere.
    random_layer = network.RandomLayer([model.output_num(), class_num], 500)
    ad_net = AdversarialNetwork(500, 500)
    if cuda:
        random_layer.cuda()
else:
    random_layer = None
    # The discriminator sees the flattened (class x feature) outer product.
    ad_net = AdversarialNetwork(model.output_num() * class_num, 500)
if cuda:
    ad_net = ad_net.cuda()
optimizer = optim.SGD(model.parameters(), lr=lr, weight_decay=0.0005, momentum=0.9)
optimizer_ad = optim.SGD(ad_net.parameters(), lr=lr, weight_decay=0.0005, momentum=0.9)

best_model = model  # placeholder until the first evaluation produces a snapshot
best_acc = 0

for epoch in range(1, epochs + 1):
    # Step decay: multiply the model's learning rate by 0.3 every third epoch.
    if epoch % 3 == 0:
        for param_group in optimizer.param_groups:
            param_group["lr"] = param_group["lr"] * 0.3

    train(model, ad_net, random_layer, svhn_loader, mnist_loader, optimizer, optimizer_ad, epoch)
    acc1 = test(epoch, model, svhn_test_loader)
    acc2 = test(epoch, model, mnist_test_loader)
    if acc2 > best_acc:
        # Snapshot the weights. Keeping a plain reference would alias the live
        # model, which continues to change in later epochs.
        best_model = copy.deepcopy(model)
        best_acc = acc2
    # Report the running best every 10th epoch.
    if epoch % 10 == 0:
        print("Best Acc so far: ", best_acc)


In [None]:
# Save the best model under snapshot/s2m_model/, named after its accuracy.
# `os.path` is used directly because no `osp` alias is imported in this
# notebook, and the directory is created first so the save cannot fail on a
# missing folder.
snapshot_dir = "snapshot/s2m_model"
os.makedirs(snapshot_dir, exist_ok=True)
torch.save(best_model, os.path.join(snapshot_dir, "s2m_{}".format(str(best_acc))))

In [None]:
# Also write the best model to a fixed file name in the working directory.
torch.save(obj=best_model, f='DTNS-M_91_86.pth')

In [None]:
from IPython.display import FileLink
# Show a clickable download link for the checkpoint file as the cell output.
checkpoint_link = FileLink(r'DTNS-M_91_86.pth')
checkpoint_link