In [None]:
from __future__ import print_function

import copy
import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from torch.autograd import Variable
from torchvision import datasets, transforms

In [None]:
# Mount Google Drive at /content/gdrive so files under "My Drive" are reachable from this kernel.
from google.colab import drive

drive.mount(mountpoint='/content/gdrive')

In [None]:
%cd gdrive/My Drive/Deep learning

# Data loaders

In [None]:
def load_training(root_path, dir, batch_size, kwargs=None):
    """Build a shuffled, augmented training DataLoader over ``root_path/dir``.

    Images are resized to 256x256, randomly cropped to 224x224, randomly
    flipped horizontally and converted to tensors.

    Args:
        root_path: dataset root directory (a trailing separator is optional).
        dir: sub-directory of ``root_path`` laid out for ``datasets.ImageFolder``
            (one folder per class).
        batch_size: number of images per batch.
        kwargs: optional dict of extra keyword arguments forwarded to
            ``DataLoader`` (e.g. ``num_workers``, ``pin_memory``).

    Returns:
        A ``DataLoader`` with ``shuffle=True`` and ``drop_last=True``.
    """
    # NOTE(review): no transforms.Normalize(mean, std) is applied even though the
    # ResNet-50 backbone is loaded with pretrained weights — confirm this is intended.
    transform = transforms.Compose(
        [transforms.Resize([256, 256]),
         transforms.RandomCrop(224),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor()])
    # os.path.join instead of string concatenation so root_path need not end in "/".
    data = datasets.ImageFolder(root=os.path.join(root_path, dir), transform=transform)
    # drop_last presumably avoids a trailing batch of size 1, which BatchNorm1d
    # cannot normalise in training mode.
    train_loader = torch.utils.data.DataLoader(
        data, batch_size=batch_size, shuffle=True, drop_last=True, **(kwargs or {}))
    return train_loader

def load_testing(root_path, dir, batch_size, kwargs=None):
    """Build an evaluation DataLoader over ``root_path/dir``.

    Images are resized to 224x224 and converted to tensors; no augmentation.

    Args:
        root_path: dataset root directory (a trailing separator is optional).
        dir: sub-directory of ``root_path`` laid out for ``datasets.ImageFolder``
            (one folder per class).
        batch_size: number of images per batch.
        kwargs: optional dict of extra keyword arguments forwarded to
            ``DataLoader`` (e.g. ``num_workers``, ``pin_memory``).

    Returns:
        A ``DataLoader`` that visits every image once in a fixed order (the last
        partial batch is kept).
    """
    # NOTE(review): training resizes to 256x256 and then crops 224, whereas this
    # resizes straight to 224x224, so objects appear at a slightly different scale;
    # consider Resize(256) + CenterCrop(224) if that was not intentional.
    transform = transforms.Compose(
        [transforms.Resize([224, 224]),
         transforms.ToTensor()])
    data = datasets.ImageFolder(root=os.path.join(root_path, dir), transform=transform)
    # Evaluation metrics do not depend on order; a fixed order keeps runs reproducible.
    test_loader = torch.utils.data.DataLoader(
        data, batch_size=batch_size, shuffle=False, **(kwargs or {}))
    return test_loader

# Source classifier (includes the three source feature extractors)

In [None]:
class SourceClassifer(nn.Module):
    """Multi-source classifier: one shared ResNet-50 backbone plus three source heads.

    Each source (1, 2, 3) has its own MLP feature extractor
    (``extractor1..3``) on top of the backbone's pooled features and its own
    linear classifier (``cls1..3``).

    Args:
        f_dim: width of the feature vector produced by each extractor.
        n_classes: number of output classes of every classifier head.
    """

    def __init__(self, f_dim=256, n_classes=65):
        super(SourceClassifer, self).__init__()

        self.f_dim = f_dim
        self.n_classes = n_classes

        # Pretrained ResNet-50 fetched through torch.hub; replacing `fc` with
        # Identity makes the backbone return its pooled features instead of logits.
        ResNet50 = torch.hub.load('pytorch/vision:v0.6.0', 'resnet50', pretrained=True)
        in_dim = ResNet50.fc.in_features  # 2048 for ResNet-50
        ResNet50.fc = nn.Identity()
        self.ResNet50 = ResNet50

        # Created in the same order as the classifiers below so parameter
        # initialisation and state_dict keys (extractorN.*, clsN.*) are unchanged.
        self.extractor1 = self._make_extractor(self.f_dim, in_dim)
        self.extractor2 = self._make_extractor(self.f_dim, in_dim)
        self.extractor3 = self._make_extractor(self.f_dim, in_dim)

        self.cls1 = nn.Linear(self.f_dim, self.n_classes)
        self.cls2 = nn.Linear(self.f_dim, self.n_classes)
        self.cls3 = nn.Linear(self.f_dim, self.n_classes)

    @staticmethod
    def _make_extractor(f_dim, in_dim=2048, hidden_dim=1024):
        """Build one per-source MLP extractor: in_dim -> hidden_dim -> f_dim."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),  # expects 2-D (batch, features) input
            nn.ELU(),
            nn.Linear(hidden_dim, f_dim),
            nn.ELU(),
            nn.Linear(f_dim, f_dim),
            nn.BatchNorm1d(f_dim),
            nn.ELU(),
        )

    def forward(self, data_src, label_src=0, mark=1, training=True):
        """Run the backbone and either one head (training) or all heads (inference).

        Args:
            data_src: image batch fed to the ResNet-50 backbone.
            label_src: class-index targets; only used when ``training`` is true.
            mark: which source head to train (1, 2 or 3); ignored at inference.
            training: if true, return the cross-entropy loss of head ``mark``;
                otherwise return ``(pred1, pred2, pred3, feature1, feature2, feature3)``.

        Raises:
            ValueError: if ``training`` is true and ``mark`` is not 1, 2 or 3.
        """
        heads = {
            1: (self.extractor1, self.cls1),
            2: (self.extractor2, self.cls2),
            3: (self.extractor3, self.cls3),
        }
        # Fail loudly instead of silently returning None for an unknown head.
        if training and mark not in heads:
            raise ValueError('mark must be 1, 2 or 3, got {!r}'.format(mark))

        # Flatten to (batch_size, dim) because the extractors' BatchNorm1d expects 2-D input.
        h1 = torch.flatten(self.ResNet50(data_src), start_dim=1)

        if training:
            extractor, classifier = heads[mark]
            return F.cross_entropy(classifier(extractor(h1)), label_src)

        feature1 = self.extractor1(h1)
        feature2 = self.extractor2(h1)
        feature3 = self.extractor3(h1)
        return (self.cls1(feature1), self.cls2(feature2), self.cls3(feature3),
                feature1, feature2, feature3)


# Training and test functions

In [None]:
batch_size = 16
# Iterations per call to train(); each iteration draws one batch from every source.
# Derived from batch_size so changing one does not silently desync the other.
# NOTE(review): 6000 looks like a nominal number of samples per epoch — confirm.
iteration = 6000 // batch_size
epoch = 10
# Use the GPU when one is visible instead of crashing on CPU-only kernels.
cuda = torch.cuda.is_available()
seed = 8
log_interval = 20  # print training losses every `log_interval` iterations
class_num = 65
root_path = "./Dataset/"
source1_name = "Art"
source2_name = 'Clipart'
source3_name = 'Product'
target_name = "Real World"

torch.manual_seed(seed)
if cuda:
    torch.cuda.manual_seed(seed)

# Extra DataLoader options; pinned host memory only helps host-to-GPU copies.
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

def _next_batch(batch_iter, loader):
    """Return ``(batch, batch_iter)``, starting a new pass over ``loader`` when exhausted."""
    try:
        return next(batch_iter), batch_iter
    except StopIteration:
        # Only exhaustion restarts the loader; any other error now propagates.
        batch_iter = iter(loader)
        return next(batch_iter), batch_iter


def train(model):
    """Train the three source heads of ``model`` for ``iteration`` iterations.

    Every iteration takes one batch from each source domain in turn
    (source1 -> mark=1, source2 -> mark=2, source3 -> mark=3) and performs one
    Adam step on that head's cross-entropy loss, which also updates the shared
    ResNet50 backbone. Reads the module-level configuration (root_path,
    batch_size, kwargs, iteration, cuda, log_interval, source names).

    Returns:
        The same ``model`` object, trained in place.
    """
    source_names = (source1_name, source2_name, source3_name)
    loaders = [load_training(root_path, name, batch_size, kwargs) for name in source_names]
    iterators = [iter(loader) for loader in loaders]

    LEARNING_RATE_RES = 1e-5  # backbone learning rate
    LEARNING_RATE = 1e-4      # extractor / classifier learning rate
    # Built once per call so Adam keeps its moment estimates across steps
    # (rebuilding it every iteration reset that state on every update).
    optimizer = torch.optim.Adam([
        {'params': model.ResNet50.parameters(), 'lr': LEARNING_RATE_RES},
        {'params': model.extractor1.parameters(), 'lr': LEARNING_RATE},
        {'params': model.extractor2.parameters(), 'lr': LEARNING_RATE},
        {'params': model.extractor3.parameters(), 'lr': LEARNING_RATE},
        {'params': model.cls1.parameters(), 'lr': LEARNING_RATE},
        {'params': model.cls2.parameters(), 'lr': LEARNING_RATE},
        {'params': model.cls3.parameters(), 'lr': LEARNING_RATE},
    ])
    model.train()

    for i in range(1, iteration + 1):
        for mark, loader in enumerate(loaders, start=1):
            (source_data, source_label), iterators[mark - 1] = _next_batch(iterators[mark - 1], loader)
            source_label = source_label.reshape(-1)
            if cuda:
                source_data, source_label = source_data.cuda(), source_label.cuda()

            optimizer.zero_grad()
            loss = model(source_data, source_label, mark=mark)
            loss.backward()
            optimizer.step()

            if i % log_interval == 0:
                print('Train source{} iter: {} [({:.0f}%)]\tLoss: {:.6f}'.format(
                    mark, i, 100. * i / iteration, loss.item()))

    return model

def test(model):
    """Evaluate ``model`` on the target domain and print loss and accuracy.

    The three source heads are ensembled by averaging their raw logits (not
    softmax probabilities). Per-head correct counts are printed as well.
    Reads the module-level configuration (root_path, target_name, batch_size,
    kwargs, cuda).

    Args:
        model: module whose ``forward(x, training=False)`` returns
            ``(pred1, pred2, pred3, feature1, feature2, feature3)``.

    Returns:
        0-dim tensor: number of target images the ensemble classifies correctly.
    """
    model.eval()
    target_test_loader = load_testing(root_path, target_name, batch_size, kwargs)
    num_samples = len(target_test_loader.dataset)

    test_loss = 0.0
    # Correct counts for [ensemble, head1, head2, head3]; kept as tensors so the
    # returned count supports `.item()` and tensor comparisons.
    correct_counts = [torch.zeros((), dtype=torch.long) for _ in range(4)]

    with torch.no_grad():
        for data, target in target_test_loader:
            if cuda:
                data, target = data.cuda(), target.cuda()
            pred1, pred2, pred3, _, _, _ = model(data, training=False)

            pred = (pred1 + pred2 + pred3) / 3
            # Sum (not the default per-batch mean) so dividing by num_samples
            # below yields the true per-sample average loss.
            test_loss += F.cross_entropy(pred, target, reduction='sum').item()

            for k, logits in enumerate((pred, pred1, pred2, pred3)):
                hits = logits.argmax(dim=1).eq(target.view(-1))
                correct_counts[k] += hits.sum().cpu()

    correct, correct1, correct2, correct3 = correct_counts
    test_loss /= num_samples
    print(target_name, '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, num_samples,
        100. * correct.item() / num_samples))
    print('\nsource1 accnum {}, source2 accnum {}，source3 accnum {}'.format(correct1, correct2, correct3))
    return correct

if __name__ == '__main__':
    # Train for `epoch` rounds, evaluating on the target after each one and
    # keeping a copy of the weights with the highest target accuracy.
    model = SourceClassifer(n_classes=class_num)
    best_model_wts = copy.deepcopy(model.state_dict())
    if cuda:
        model.cuda()
    correct = 0
    for _ in range(epoch):
        model = train(model)
        t_correct = test(model)
        if t_correct > correct:
            correct = t_correct
            best_model_wts = copy.deepcopy(model.state_dict())
        # int() works whether `correct` is still the initial 0 or a tensor count;
        # `.item()` raised AttributeError when no epoch had improved on 0 yet.
        print(source1_name, source2_name, source3_name, "to", target_name,
              "%s max correct:" % target_name, int(correct), "\n")

In [None]:
# Persist the best weights; create ./results first so torch.save does not fail
# when the directory does not exist yet.
os.makedirs("./results", exist_ok=True)
torch.save(best_model_wts, "./results/source_classifer_ACP")

In [None]:
# Reload the saved best weights into a fresh model and evaluate that model,
# not the last-epoch `model` still held in memory.
sourceClassifier = SourceClassifer(n_classes=class_num)
# map_location lets a checkpoint saved from GPU tensors load on any machine.
sourceClassifier.load_state_dict(torch.load("./results/source_classifer_ACP", map_location='cpu'))
if cuda:
    # test() moves each batch to the GPU when `cuda` is set, so the model must be there too.
    sourceClassifier.cuda()
correct = test(sourceClassifier)