In [1]:
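"""
Class-conditional WDGRL with clustering losses, adapting SVHN (source) -> MNIST (target).

Based on Wasserstein Distance Guided Representation Learning, Shen et al. (2017).
Unlike the original method, the Wasserstein term is not estimated by a learned critic:
it is a Sinkhorn divergence (geomloss ``SamplesLoss``) computed separately for each class.
"""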

###library loading###
import argparse
import random
import torch
import numpy as np
from torch import nn
import math
import pandas as pd
import os
import sys
sys.path.insert(0, '../../')
import itertools
from torch.autograd import grad
from torch.utils.data import DataLoader,Subset
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor
import torch.nn.functional as F
from tqdm import tqdm, trange
from sklearn.neighbors import KernelDensity
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture
from torchvision import datasets, transforms
from data.mnist_mnistm_data import *
from data.svhn import *
from torch.nn.utils import spectral_norm
from models.model_svhn_mnist import *
import utils.config as config
from utils.helper import *
from geomloss import SamplesLoss
# Run on the CUDA device when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
# --- data / schedule ---------------------------------------------------------
K = 10                # number of classes, used as num_classes for the one-hot cluster matrices
batch_size = 512      # combined batch; each domain's loader uses batch_size // 2
epochs = 100

# --- SGD optimizer -------------------------------------------------------------
lr = 1e-2
momentum = 0.4

# --- loss weights --------------------------------------------------------------
lamb_clf = 1          # importance-weighted source classification loss
lamb_wd = 0.5         # class-conditional Sinkhorn / Wasserstein alignment term
lamb_centroid = 1     # centroid loss
lamb_sntg = 1         # SNTG clustering loss (source + target)
LAMBDA = 50           # forwarded to sntg_loss_func(LAMBDA=...)

In [3]:
# Fix the random seed before any model or DataLoader is built below, so runs are repeatable.
# seed_torch is provided by the star-import of utils.helper; presumably it seeds the
# python / numpy / torch RNGs — TODO confirm against utils/helper.py.
seed_torch(seed=123)

In [None]:
best_accu_t = 0.0
# Mixing factor for the running class-proportion estimates (w_src / w_tgt) below:
# 1 keeps only the current batch, values in (0, 1) give an exponential moving average.
weight_update = 1

##initialize models
# Both encoders are initialised from the same pretrained source checkpoint and are then
# fine-tuned with separate optimizers.
model_feature_source = Net().to(device)
model_feature_target = Net().to(device)
model_clf = Classifier().to(device)
path_to_model = 'trained_models/'
# map_location lets checkpoints saved on a GPU be loaded on a CPU-only machine (and vice versa).
model_feature_source.load_state_dict(torch.load(path_to_model+'source_feature_rd_128_SGD_sntg.pt', map_location=device))
model_feature_target.load_state_dict(torch.load(path_to_model+'source_feature_rd_128_SGD_sntg.pt', map_location=device))
model_clf.load_state_dict(torch.load(path_to_model+'source_clf_rd_128_SGD_sntg.pt', map_location=device))

# Each domain contributes half of the combined batch.
half_batch = batch_size // 2
##########source data########################
source_dataset = load_SVHN_LS(target_num=1500,train_flag='train')
source_loader = DataLoader(source_dataset, batch_size=half_batch, drop_last=True,\
                           shuffle=True, num_workers=0, pin_memory=True)

##########target data########################
target_dataset_label = load_mnist_LS(source_num=290,train_flag=False)

target_loader = DataLoader(target_dataset_label, batch_size=half_batch,drop_last=True,\
                           shuffle=True, num_workers=1, pin_memory=True)

##evaluation data: the whole target set in a single batch (built once, reused every epoch)
eval_dataset = target_dataset_label
eval_dataloader_all = DataLoader(eval_dataset, batch_size=len(eval_dataset), shuffle=False,
                                 drop_last=False, num_workers=0, pin_memory=True)

##initialize model optimizers
feature_source_optim = torch.optim.SGD(model_feature_source.parameters(), lr=lr,momentum=momentum)
feature_target_optim = torch.optim.SGD(model_feature_target.parameters(), lr=lr,momentum=momentum)
clf_optim = torch.optim.SGD(model_clf.parameters(), lr=lr,momentum=momentum)

# Per-sample CE (needed for the class-weighted loss) and batch-mean CE (reporting only).
clf_criterion = nn.CrossEntropyLoss(reduction='none')
clf_criterion_unweight = nn.CrossEntropyLoss()
# Sinkhorn divergence between two point clouds, used as the Wasserstein estimate.
w1_loss = SamplesLoss(loss="sinkhorn", p=1, blur=.05)
soft_f = nn.Softmax(dim=1)

##initialize clustering weights: uniform class proportions, unit importance weights
w_src = torch.ones(K,1,requires_grad=False).divide(K).to(device)
w_tgt = torch.ones(K,1,requires_grad=False).divide(K).to(device)
w_imp = torch.ones(K,1,requires_grad=False).to(device)

##initialize losses
mean_w1_loss_all = []
mean_clf_loss_all = []
mean_unweighted_clf_loss_all = []
mean_centroid_loss_all = []
mean_sntg_loss_all = []
mean_accuracy_all = []

for epoch in range(1, epochs+1):
    # Evaluation at the end of every epoch switches to eval mode and freezes the
    # parameters, so restore training mode and gradients here.
    for net in (model_feature_source, model_feature_target, model_clf):
        net.train()
        set_requires_grad(net, requires_grad=True)

    source_batch_iterator = iter(source_loader)
    target_batch_iterator = iter(target_loader)
    len_dataloader = min(len(source_loader), len(target_loader))
    assert len_dataloader > 0, 'source/target dataset smaller than half_batch (drop_last=True)'

    total_unweight_clf_loss = 0
    total_clf_loss = 0
    total_centroid_loss = 0
    total_sntg_loss = 0
    total_w1_loss = 0

    for i in range(len_dataloader):
        # DataLoader iterators have no .next() method in recent PyTorch; use built-in next().
        source_x, source_y = next(source_batch_iterator)
        target_x, _ = next(target_batch_iterator)
        source_x, target_x = source_x.to(device), target_x.to(device)
        source_y = source_y.to(torch.int64).to(device)

        ##extract latent features
        source_feature = model_feature_source(source_x)
        target_feature = model_feature_target(target_x)
        # Second target pass used only for pseudo-labels (argmax), so no autograd graph is needed.
        with torch.no_grad():
            target_feature_2 = model_feature_target(target_x)

        ##unweighted classification loss
        source_preds = model_clf(source_feature)
        clf_loss_unweight = clf_criterion(source_preds, source_y)
        report_clf_loss_unweight = clf_criterion_unweight(source_preds, source_y)
        with torch.no_grad():
            target_preds = torch.argmax(model_clf(target_feature_2), 1)

        ##get clustering information: one-hot class / pseudo-label matrices
        cluster_s = F.one_hot(source_y, num_classes=K).float()
        target_y = target_preds  # argmax already yields int64 on `device`
        cluster_t = F.one_hot(target_y, num_classes=K).float()

        ##weighted classification loss: per-class share of the batch error, reweighted by w_imp
        weighted_clf_err = cluster_s * clf_loss_unweight.reshape(-1,1)
        expected_clf_err = torch.mean(weighted_clf_err, dim=0)
        clf_loss = torch.sum(expected_clf_err.reshape(K,1) * w_imp)

        ##weighted domain invariant loss: w_tgt-weighted sum of per-class Sinkhorn distances
        # Start from a tensor (not the int 0) so `.item()` below also works when no class
        # occurs in both the source batch and the target pseudo-labels.
        wasserstein_distance = torch.zeros(1, device=device)
        for cluster_id in range(K):
            src_mask = source_y == cluster_id
            tgt_mask = target_preds == cluster_id
            if src_mask.any() and tgt_mask.any():
                wasserstein_distance = wasserstein_distance + w_tgt[cluster_id] * w1_loss(
                    source_feature[src_mask], target_feature[tgt_mask])

        ##clustering loss
        #L_orthogonal
        source_sntg_loss = sntg_loss_func(cluster=cluster_s,feature=source_feature,LAMBDA=LAMBDA)
        target_sntg_loss = sntg_loss_func(cluster=cluster_t,feature=target_feature,LAMBDA=LAMBDA)

        ##calculate centroids
        centroid_loss = centroid_loss_func(K,device,source_y,target_y,source_feature,target_feature)
        sntg_loss = source_sntg_loss + target_sntg_loss

        loss = lamb_clf*clf_loss + lamb_wd * wasserstein_distance + lamb_centroid*centroid_loss + lamb_sntg*sntg_loss

        #update class-proportion estimates and importance weights (used from the next batch on)
        with torch.no_grad():
            w_src_batch = cluster_s.mean(dim=0)
            w_tgt_batch = cluster_t.mean(dim=0)
            w_src = w_src * (1 - weight_update) + w_src_batch.reshape(K,1) * weight_update
            w_tgt = w_tgt * (1 - weight_update) + w_tgt_batch.reshape(K,1) * weight_update
            # A class with zero source proportion would give x/0 or 0/0 and turn the next
            # clf_loss into inf/NaN; such classes get a neutral importance weight of 1.
            w_imp = torch.where(w_src > 0,
                                w_tgt / w_src.clamp_min(torch.finfo(w_src.dtype).tiny),
                                torch.ones_like(w_src))

        #backprop feature extraction+classifier
        feature_source_optim.zero_grad()
        feature_target_optim.zero_grad()
        clf_optim.zero_grad()
        loss.backward()
        feature_source_optim.step()
        feature_target_optim.step()
        clf_optim.step()

        total_w1_loss += wasserstein_distance.item()
        total_unweight_clf_loss += report_clf_loss_unweight.item()
        total_clf_loss += clf_loss.item()
        total_centroid_loss += centroid_loss.item()
        total_sntg_loss += sntg_loss.item()

    mean_clf_loss = total_clf_loss/(len_dataloader)
    mean_unweighted_clf_loss = total_unweight_clf_loss/(len_dataloader)
    mean_centroid_loss = total_centroid_loss/(len_dataloader)
    mean_sntg_loss = total_sntg_loss/(len_dataloader)
    mean_w1_loss = total_w1_loss/(len_dataloader)

    mean_clf_loss_all.append(mean_clf_loss)
    mean_centroid_loss_all.append(mean_centroid_loss)
    mean_sntg_loss_all.append(mean_sntg_loss)
    mean_unweighted_clf_loss_all.append(mean_unweighted_clf_loss)
    mean_w1_loss_all.append(mean_w1_loss)

    # NOTE: "critic_loss" reports the class-weighted Sinkhorn distance (no critic network is trained).
    tqdm.write(f'EPOCH {epoch:03d}: critic_loss={mean_w1_loss:.4f} source_clf={mean_clf_loss:.4f},unweighted_source_clf={mean_unweighted_clf_loss:.4f},clustering_centr={mean_centroid_loss:.4f},sntg={mean_sntg_loss:.4f}')

    ##evaluate models on target domain##
    # eval() disables dropout and makes BatchNorm (if the networks contain any) use running
    # statistics instead of the evaluation batch; training mode is restored at the next epoch.
    for net in (model_feature_source, model_feature_target, model_clf):
        net.eval()
        set_requires_grad(net, requires_grad=False)

    correct = 0
    total = 0
    with torch.no_grad():
        for x, y_true in tqdm(eval_dataloader_all, leave=False):
            x, y_true = x.to(device), y_true.to(device)
            y_pred = model_clf(model_feature_target(x))
            pred_labels = torch.argmax(y_pred, 1)
            # Predicted class histogram (and, on the first epoch, the true one) to monitor label shift.
            print(pd.Series(pred_labels.cpu().numpy()).value_counts())
            if epoch == 1:
                print(pd.Series(y_true.cpu().numpy()).value_counts())
            correct += (pred_labels == y_true).sum().item()
            total += y_true.size(0)

    # Sample-weighted accuracy, correct for any evaluation batch size.
    mean_accuracy = correct / max(total, 1)
    mean_accuracy_all.append(mean_accuracy)
    print(f'Accuracy on target data: {mean_accuracy:.4f}')

    #save all the losses
    #save_parameters(mean_clf_loss_all=mean_clf_loss_all,mean_unweighted_clf_loss_all=mean_unweighted_clf_loss_all,\
    #                mean_centroid_loss_all=mean_centroid_loss_all,mean_sntg_loss_all=mean_sntg_loss_all,\
    #                mean_accuracy_all=mean_accuracy_all, mean_w1_loss_all=mean_w1_loss_all,\
    #                key_words='clf1wd0.5c1s1'+str(momentum))

    if mean_accuracy > best_accu_t:
        best_accu_t = mean_accuracy
        #save the model with the best performance
        #torch.save(model_feature_source.state_dict(),'trained_models/wdgrl_source_clf1wd0.5c1s1_'+str(momentum)+'.pt')
        #torch.save(model_feature_target.state_dict(),'trained_models/wdgrl_target_clf1wd0.5c1s1_'+str(momentum)+'.pt')
        #torch.save(model_clf.state_dict(),'trained_models/wdgrl_clf_clf1wd0.5c1s1_'+str(momentum)+'.pt')

    print(best_accu_t)




Using downloaded and verified file: ../../data/svhn/train_32x32.mat


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 001: critic_loss=31.9743 source_clf=0.2313,unweighted_source_clf=0.1991,clustering_centr=2.8359,sntg=13.6849


                                             

7    1060
4     838
3     811
5     694
2     546
6     530
1     404
9     314
8     311
0     292
Name: 0, dtype: int64
1    870
5    870
9    870
3    870
7    870
0    290
4    290
8    290
2    290
6    290
Name: 0, dtype: int64
Accuracy on target data: 0.6564
0.6563793420791626


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 002: critic_loss=27.8615 source_clf=0.2589,unweighted_source_clf=0.2000,clustering_centr=1.8760,sntg=10.2727


                                             

4    1012
7     953
3     803
5     696
2     629
6     500
1     405
8     306
0     278
9     218
Name: 0, dtype: int64
Accuracy on target data: 0.6445
0.6563793420791626


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 003: critic_loss=25.6267 source_clf=0.2254,unweighted_source_clf=0.1924,clustering_centr=1.6706,sntg=9.3895


                                             

7    1058
4     984
3     856
5     763
2     614
6     424
1     367
0     279
8     263
9     192
Name: 0, dtype: int64
Accuracy on target data: 0.6693
0.6693103909492493


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 004: critic_loss=23.8616 source_clf=0.1986,unweighted_source_clf=0.1850,clustering_centr=1.5555,sntg=8.6091


                                             

7    1131
4    1026
3     840
5     782
2     532
6     433
1     365
0     274
8     258
9     159
Name: 0, dtype: int64
Accuracy on target data: 0.6779
0.6779310703277588


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 005: critic_loss=23.2196 source_clf=0.1457,unweighted_source_clf=0.1666,clustering_centr=1.5773,sntg=8.1510


                                             

7    1204
4     991
3     855
5     809
2     486
6     398
1     344
0     272
8     259
9     182
Name: 0, dtype: int64
Accuracy on target data: 0.6988
0.6987931132316589


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 006: critic_loss=22.3732 source_clf=0.2117,unweighted_source_clf=0.2209,clustering_centr=1.3107,sntg=7.7761


                                             

7    1186
4    1043
3     871
5     835
2     495
6     354
1     349
0     275
8     252
9     140
Name: 0, dtype: int64
Accuracy on target data: 0.6978
0.6987931132316589


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 007: critic_loss=22.0358 source_clf=0.2160,unweighted_source_clf=0.1856,clustering_centr=1.2774,sntg=7.6671


                                             

7    1205
4    1035
3     873
5     838
2     483
1     338
6     335
0     272
8     256
9     165
Name: 0, dtype: int64
Accuracy on target data: 0.7038
0.7037931084632874


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 008: critic_loss=21.5628 source_clf=0.1881,unweighted_source_clf=0.1912,clustering_centr=1.3372,sntg=7.3511


                                             

7    1201
4    1020
3     872
5     851
2     479
1     343
6     325
0     276
8     259
9     174
Name: 0, dtype: int64
Accuracy on target data: 0.7086
0.7086207270622253


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 009: critic_loss=20.9614 source_clf=0.1614,unweighted_source_clf=0.1682,clustering_centr=1.4311,sntg=7.0949


                                             

7    1224
4    1020
3     858
5     852
2     456
1     365
6     339
0     271
8     262
9     153
Name: 0, dtype: int64
Accuracy on target data: 0.7160
0.7160345315933228


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 010: critic_loss=20.2008 source_clf=0.2291,unweighted_source_clf=0.2083,clustering_centr=1.1451,sntg=6.9816


                                             

7    1202
4    1014
5     859
3     847
2     478
1     356
6     335
0     269
8     260
9     180
Name: 0, dtype: int64
Accuracy on target data: 0.7140
0.7160345315933228


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 011: critic_loss=20.1619 source_clf=0.2035,unweighted_source_clf=0.1873,clustering_centr=1.1414,sntg=6.7178


                                             

7    1204
4    1029
5     868
3     866
2     458
1     363
6     333
0     263
8     258
9     158
Name: 0, dtype: int64
Accuracy on target data: 0.7164
0.7163793444633484


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 012: critic_loss=20.0846 source_clf=0.1873,unweighted_source_clf=0.1941,clustering_centr=1.1973,sntg=6.5419


                                             

7    1183
4    1036
5     864
3     852
2     473
1     366
6     326
0     277
8     270
9     153
Name: 0, dtype: int64
Accuracy on target data: 0.7183
0.7182759046554565


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 013: critic_loss=19.8522 source_clf=0.1640,unweighted_source_clf=0.1831,clustering_centr=1.1905,sntg=6.4398


                                             

7    1222
4    1030
5     877
3     864
2     448
1     347
6     323
0     272
8     267
9     150
Name: 0, dtype: int64
Accuracy on target data: 0.7226
0.7225862145423889


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 014: critic_loss=19.3378 source_clf=0.1290,unweighted_source_clf=0.1537,clustering_centr=1.2024,sntg=6.2853


                                             

7    1258
4    1029
5     878
3     866
2     439
1     335
6     331
0     270
8     255
9     139
Name: 0, dtype: int64
Accuracy on target data: 0.7197
0.7225862145423889


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 015: critic_loss=18.9139 source_clf=0.1547,unweighted_source_clf=0.1846,clustering_centr=1.0818,sntg=6.2365


                                             

7    1259
4    1052
5     870
3     850
2     433
1     348
6     327
0     274
8     265
9     122
Name: 0, dtype: int64
Accuracy on target data: 0.7224
0.7225862145423889


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 016: critic_loss=18.5428 source_clf=0.1572,unweighted_source_clf=0.1838,clustering_centr=1.0379,sntg=5.9577


                                             

7    1254
4    1019
5     875
3     868
2     444
1     335
6     319
0     277
8     261
9     148
Name: 0, dtype: int64
Accuracy on target data: 0.7229
0.7229310274124146


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 017: critic_loss=18.0851 source_clf=0.1488,unweighted_source_clf=0.1590,clustering_centr=1.0717,sntg=5.8245


                                             

7    1294
4    1026
5     864
3     857
2     414
1     333
6     324
0     273
8     263
9     152
Name: 0, dtype: int64
Accuracy on target data: 0.7291
0.7291379570960999


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 018: critic_loss=18.0390 source_clf=0.1586,unweighted_source_clf=0.1698,clustering_centr=1.0549,sntg=5.8362


                                             

7    1269
4    1038
3     860
5     856
2     418
1     351
6     325
0     276
8     265
9     142
Name: 0, dtype: int64
Accuracy on target data: 0.7293
0.7293103933334351


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 019: critic_loss=17.8690 source_clf=0.1380,unweighted_source_clf=0.1493,clustering_centr=1.0968,sntg=5.7149


                                             

7    1271
4    1031
5     869
3     861
2     438
1     331
6     323
0     277
8     263
9     136
Name: 0, dtype: int64
Accuracy on target data: 0.7245
0.7293103933334351


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 020: critic_loss=17.5569 source_clf=0.1291,unweighted_source_clf=0.1650,clustering_centr=1.0246,sntg=5.7193


                                             

7    1288
4    1010
5     866
3     865
2     422
1     341
6     330
0     270
8     261
9     147
Name: 0, dtype: int64
Accuracy on target data: 0.7303
0.730344831943512


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 021: critic_loss=17.9484 source_clf=0.1831,unweighted_source_clf=0.1824,clustering_centr=1.0313,sntg=5.6290


                                             

7    1267
4    1017
5     872
3     864
2     430
1     338
6     328
0     271
8     260
9     153
Name: 0, dtype: int64
Accuracy on target data: 0.7279
0.730344831943512


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 022: critic_loss=17.2757 source_clf=0.1435,unweighted_source_clf=0.1681,clustering_centr=1.0775,sntg=5.4672


                                             

7    1264
4    1039
5     867
3     866
2     440
1     338
6     320
0     276
8     269
9     121
Name: 0, dtype: int64
Accuracy on target data: 0.7253
0.730344831943512


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 023: critic_loss=17.4731 source_clf=0.1550,unweighted_source_clf=0.1620,clustering_centr=1.1466,sntg=5.6097


                                             

7    1231
4    1022
5     873
3     861
2     460
1     346
6     319
0     277
8     270
9     141
Name: 0, dtype: int64
Accuracy on target data: 0.7274
0.730344831943512


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 024: critic_loss=17.4185 source_clf=0.1429,unweighted_source_clf=0.1466,clustering_centr=1.1451,sntg=5.4342


                                             

7    1262
4    1012
5     877
3     849
2     436
1     339
6     321
8     275
0     271
9     158
Name: 0, dtype: int64
Accuracy on target data: 0.7293
0.730344831943512


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 025: critic_loss=17.0281 source_clf=0.1786,unweighted_source_clf=0.1622,clustering_centr=1.0256,sntg=5.4134


                                             

7    1262
4    1027
5     871
3     861
2     437
1     334
6     315
8     275
0     274
9     144
Name: 0, dtype: int64
Accuracy on target data: 0.7284
0.730344831943512


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 026: critic_loss=16.8769 source_clf=0.1420,unweighted_source_clf=0.1790,clustering_centr=1.0589,sntg=5.4616


                                             

7    1273
4    1029
5     872
3     862
2     429
1     336
6     315
0     273
8     273
9     138
Name: 0, dtype: int64
Accuracy on target data: 0.7293
0.730344831943512


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 027: critic_loss=16.7705 source_clf=0.1533,unweighted_source_clf=0.1558,clustering_centr=0.9721,sntg=5.2015


                                             

7    1250
4    1016
5     868
3     868
2     441
1     344
6     310
0     276
8     270
9     157
Name: 0, dtype: int64
Accuracy on target data: 0.7302
0.730344831943512


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 028: critic_loss=16.7045 source_clf=0.1376,unweighted_source_clf=0.1599,clustering_centr=0.9573,sntg=5.1492


                                             

7    1245
4    1011
5     866
3     862
2     433
1     353
6     326
8     277
0     271
9     156
Name: 0, dtype: int64
Accuracy on target data: 0.7341
0.7341379523277283


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 029: critic_loss=16.8224 source_clf=0.1366,unweighted_source_clf=0.1697,clustering_centr=0.9078,sntg=5.2671


                                             

7    1265
4    1017
5     869
3     860
2     432
1     343
6     329
8     271
0     264
9     150
Name: 0, dtype: int64
Accuracy on target data: 0.7328
0.7341379523277283


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 030: critic_loss=16.3393 source_clf=0.1346,unweighted_source_clf=0.1686,clustering_centr=0.9044,sntg=5.1115


                                             

7    1227
4    1031
5     873
3     862
2     464
1     339
6     321
8     276
0     263
9     144
Name: 0, dtype: int64
Accuracy on target data: 0.7243
0.7341379523277283


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 031: critic_loss=16.0583 source_clf=0.1385,unweighted_source_clf=0.1747,clustering_centr=0.9969,sntg=4.9227


                                             

7    1247
4    1024
5     872
3     858
2     437
1     351
6     319
8     270
0     267
9     155
Name: 0, dtype: int64
Accuracy on target data: 0.7307
0.7341379523277283


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 032: critic_loss=16.0974 source_clf=0.1399,unweighted_source_clf=0.1624,clustering_centr=0.8775,sntg=4.9580


                                             

7    1240
4    1018
5     878
3     853
2     439
1     348
6     321
8     272
0     271
9     160
Name: 0, dtype: int64
Accuracy on target data: 0.7307
0.7341379523277283


  0%|          | 0/1 [00:00<?, ?it/s]

EPOCH 033: critic_loss=15.7061 source_clf=0.1552,unweighted_source_clf=0.1559,clustering_centr=0.9075,sntg=4.8493


                                             

7    1241
4    1023
5     874
3     859
2     438
1     352
6     320
0     276
8     274
9     143
Name: 0, dtype: int64
Accuracy on target data: 0.7329
0.7341379523277283


