In [1]:
from torch.utils import data
from torchvision import datasets, transforms,models

# Compute the per-channel mean and std of an image dataset
def get_dataset_mean_and_std(directory):
    """Return per-channel ``(mean, std)`` statistics for an image folder.

    The folder is read with ``datasets.ImageFolder`` and every image is
    converted with ``transforms.ToTensor()``.  A ``DataLoader`` with default
    settings walks the dataset, and only the first image of each yielded
    batch (``xs[0]``) is inspected.

    For each of the three channels the plane's own ``mean()`` and ``std()``
    are accumulated and finally divided by ``len(dataset)``.  The reported
    ``std`` is therefore the *average of per-image standard deviations*,
    not the standard deviation of all pixels pooled together.

    All three channels are accumulated during a single pass over the data,
    so every image is loaded and decoded only once.

    Args:
        directory: root folder handed to ``datasets.ImageFolder``.

    Returns:
        ``(mean, std)``: two lists with one value per channel (3 entries).
    """
    dataset = datasets.ImageFolder(
        directory,
        transform=transforms.Compose([
            transforms.ToTensor()
        ])
    )

    data_loader = data.DataLoader(dataset)

    num_channels = 3
    mean_sums = [0.0] * num_channels
    std_sums = [0.0] * num_channels

    for xs, _ in data_loader:
        img = xs[0].numpy()
        for channel in range(num_channels):
            plane = img[channel]
            mean_sums[channel] += plane.mean()
            std_sums[channel] += plane.std()

    count = len(dataset)
    mean = [total / count for total in mean_sums]
    std = [total / count for total in std_sums]
    return mean, std

In [2]:
# Save and load data and networks

import pickle

def save(obj, path):
    """Pickle ``obj`` into the file at ``path`` and print a confirmation.

    The file is opened in binary write mode, so an existing file at
    ``path`` is overwritten.
    """
    with open(path, 'wb') as sink:
        pickle.dump(obj, sink)
        message = '[INFO] Object saved to {}'.format(path)
        print(message)
def load_s(path):
    """Load a pickled object from ``path``, print its items and return it.

    The object is expected to be iterable: after the confirmation message
    each element is printed on its own line.  The loaded object is returned
    so callers can actually use it (previously the function returned
    ``None``).

    Args:
        path: location of a file written with ``pickle.dump``.

    Returns:
        The unpickled object.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(path, 'rb') as f:
        a = pickle.load(f)
    print('[INFO] Object loaded {}'.format(path))
    for i in a:
        # Call syntax so the cell parses on Python 3 as well as Python 2.
        print(i)
    return a
def save_net(model, path):
    """Write ``model.state_dict()`` to ``path`` with ``torch.save``.

    Only the state dict is stored, not the model object itself.  A
    confirmation line is printed afterwards.
    """
    # NOTE(review): `torch` is imported in a later cell (In [4]) of this
    # notebook; that cell must have run before this function is called.
    weights = model.state_dict()
    torch.save(weights, path)
    print('[INFO] Checkpoint saved to {}'.format(path))

def load_net(model, path, map_location=None):
    """Restore ``model``'s parameters from a checkpoint file.

    The file at ``path`` is read with ``torch.load`` and the result is
    passed to ``model.load_state_dict``; a confirmation line is printed.

    Args:
        model: module whose ``load_state_dict`` receives the checkpoint.
        path: checkpoint file, e.g. one written by ``save_net``.
        map_location: forwarded to ``torch.load``.  The default ``None``
            keeps the previous behaviour; pass e.g. ``'cpu'`` to load a
            checkpoint that was saved from a GPU on a CPU-only machine.
    """
    # NOTE(review): `torch` is imported in a later cell (In [4]) of this
    # notebook; that cell must have run before this function is called.
    state = torch.load(path, map_location=map_location)
    model.load_state_dict(state)
    print('[INFO] Checkpoint {} loaded'.format(path))

In [3]:
# Show the images of one batch from data_loader
import numpy as np
import torchvision
%matplotlib inline
import matplotlib.pyplot as plt

def imshow(inp):
    """Draw a tensor image with ``plt.imshow``.

    ``inp`` is converted to a numpy array, its axes are reordered with
    ``(1, 2, 0)`` (axis 0 moves to the end), and the values are cast to
    ``uint8`` before plotting.

    NOTE(review): the ``uint8`` cast truncates fractional values, so the
    input presumably holds values on a 0-255 scale -- confirm against the
    transforms used by the data loader.
    """
    pixels = np.transpose(inp.numpy(), (1, 2, 0))
    plt.imshow(pixels.astype(np.uint8))

# Usage sketch for imshow(), kept inside a string literal so that it is not
# executed.  As the last expression of the cell, the string's repr is echoed
# as the cell output.  It refers to a `data_loader` that is not defined in
# this part of the notebook.
'''
# Get a batch of training data
inputs, classes = next(iter(data_loader))
# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

imshow(out)
plt.title([x for x in classes])
plt.show()
'''

'\n# Get a batch of training data\ninputs, classes = next(iter(data_loader))\n# Make a grid from batch\nout = torchvision.utils.make_grid(inputs)\n\nimshow(out)\nplt.title([x for x in classes])\nplt.show()\n'

In [4]:
# PyTorch implementation of MMD

import torch

def _mix_rbf_kernel(X, Y, sigma_list):
    """Evaluate a sum of RBF kernels on the rows of ``X`` and ``Y``.

    ``X`` and ``Y`` must have the same size along dim 0 (asserted).  They
    are concatenated along dim 0, pairwise squared distances are derived
    from the Gram matrix of the stacked rows, and for every ``sigma`` in
    ``sigma_list`` the term ``exp(-d2 / (2 * sigma**2))`` is added to the
    kernel.

    Returns:
        ``(K_XX, K_XY, K_YY, len(sigma_list))`` where the three matrices
        are the corresponding blocks of the combined kernel matrix.
    """
    assert(X.size(0) == Y.size(0))
    n_rows = X.size(0)

    stacked = torch.cat((X, Y), 0)
    gram = torch.mm(stacked, stacked.t())
    # Squared row norms as a column, repeated to the shape of the Gram matrix.
    sq_norms = torch.diag(gram).unsqueeze(1).expand_as(gram)
    # ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2
    sq_dists = sq_norms - 2 * gram + sq_norms.t()

    kernel = sum(
        (torch.exp(-(1.0 / (2 * sigma**2)) * sq_dists) for sigma in sigma_list),
        0.0,
    )

    top, bottom = slice(None, n_rows), slice(n_rows, None)
    return (kernel[top, top], kernel[top, bottom], kernel[bottom, bottom],
            len(sigma_list))

def mix_rbf_mmd2(X, Y, sigma_list, biased=True):
    """Compute MMD^2 between ``X`` and ``Y`` under a mixture of RBF kernels.

    The kernel blocks come from ``_mix_rbf_kernel`` and are handed to
    ``_mmd2``.  The kernel count returned alongside them is not forwarded:
    ``const_diagonal`` is always ``False``, so ``_mmd2`` reads the diagonals
    from the kernel matrices themselves.
    """
    kernel_blocks = _mix_rbf_kernel(X, Y, sigma_list)
    k_xx, k_xy, k_yy = kernel_blocks[:3]
    return _mmd2(k_xx, k_xy, k_yy, const_diagonal=False, biased=biased)

def _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
    """Combine kernel blocks into an MMD^2 value.

    Args:
        K_XX, K_XY, K_YY: kernel matrices; X and Y are assumed to have the
            same number of samples, taken from ``K_XX.size(0)``.
        const_diagonal: when not ``False``, used as the value of every
            diagonal entry of ``K_XX`` and ``K_YY`` instead of reading the
            diagonals from the matrices.
        biased: if true, diagonal terms are included and the within-sample
            sums are divided by ``m * m``; otherwise the diagonal is left
            out and they are divided by ``m * (m - 1)``.

    Returns:
        The sum of the two within-sample terms minus twice the
        cross-sample term.
    """
    n = K_XX.size(0)

    # Diagonal entries and their totals for both within-sample matrices.
    if const_diagonal is False:
        diag_xx, diag_yy = torch.diag(K_XX), torch.diag(K_YY)
        trace_xx, trace_yy = torch.sum(diag_xx), torch.sum(diag_yy)
    else:
        diag_xx = diag_yy = const_diagonal
        trace_xx = trace_yy = n * const_diagonal

    # Row sums with the diagonal removed, then reduced to a single total.
    offdiag_xx = (K_XX.sum(dim=1) - diag_xx).sum()
    offdiag_yy = (K_YY.sum(dim=1) - diag_yy).sum()
    cross = K_XY.sum(dim=0).sum()

    if not biased:
        return (offdiag_xx / (n * (n - 1))
                + offdiag_yy / (n * (n - 1))
                - 2.0 * cross / (n * n))

    # Biased variant: put the diagonal totals back in.
    return ((offdiag_xx + trace_xx) / (n * n)
            + (offdiag_yy + trace_yy) / (n * n)
            - 2.0 * cross / (n * n))

# Usage sketch for mix_rbf_mmd2(), kept inside a string literal so that it is
# not executed; the cell echoes the string's repr as its output.
# `out1`, `out2` and `F` are not defined in this part of the notebook
# (presumably network outputs and torch.nn.functional -- confirm).
'''
sigma_list=[1,2,4,8,16]
#print(out1.data, out2.data)
if out1.size(0)==out2.size(0):
    mmd2_D = mix_rbf_mmd2(out1,out2, sigma_list)
    mmd2_D = F.relu(mmd2_D)
    mmd_loss=mmd2_D
'''

'\nsigma_list=[1,2,4,8,16]\n#print(out1.data, out2.data)\nif out1.size(0)==out2.size(0):\n    mmd2_D = mix_rbf_mmd2(out1,out2, sigma_list)\n    mmd2_D = F.relu(mmd2_D)\n    mmd_loss=mmd2_D\n'

In [5]:
# Load pretrained model parameters
# load AlexNet pre-trained model
from torch.utils import model_zoo

def load_pretrained(model):
    """Initialise ``model`` from the pre-trained AlexNet checkpoint.

    The state dict at the hard-coded URL is fetched with
    ``model_zoo.load_url``.  An entry is copied into ``model`` only when its
    key exists in ``model.state_dict()`` *and* its tensor size matches the
    model's tensor for that key.  Every other parameter of ``model`` keeps
    its current value.

    The size check lets a model whose layers were resized (for example a
    final fc layer with a different number of outputs) load the remaining
    weights instead of failing with a size mismatch in ``load_state_dict``.

    Args:
        model: module to update in place through ``load_state_dict``.
    """
    url = 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth'
    pretrained_dict = model_zoo.load_url(url)
    model_dict = model.state_dict()

    # Keep only entries the model can accept: same key and same tensor size.
    compatible = {
        k: v for k, v in pretrained_dict.items()
        if k in model_dict and v.size() == model_dict[k].size()
    }

    model_dict.update(compatible)
    model.load_state_dict(model_dict)


In [6]:
# Optimizer
# Example configuration kept inside a string literal so that it is not
# executed; the cell echoes the string's repr as its output.  The two fc
# parameter groups use 10x the base learning rate.  `model`, `LEARNING_RATE`
# and `MOMENTUM` are not defined in this part of the notebook.
'''
optimizer = torch.optim.SGD([
        {'params': model.sharedNet.parameters()},
        {'params': model.source_fc.parameters(), 'lr': 10*LEARNING_RATE},
        {'params': model.target_fc.parameters(), 'lr': 10*LEARNING_RATE}
    ], lr=LEARNING_RATE, momentum=MOMENTUM)
'''

"\noptimizer = torch.optim.SGD([\n        {'params': model.sharedNet.parameters()},\n        {'params': model.source_fc.parameters(), 'lr': 10*LEARNING_RATE},\n        {'params': model.target_fc.parameters(), 'lr': 10*LEARNING_RATE}\n    ], lr=LEARNING_RATE, momentum=MOMENTUM)\n"