In [17]:
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import numpy as np
from models.SRCAGG import SRCAGG

import math
from lib.config import *

from lib.utils import PCA_TSNE
# Seed every random number generator the notebook relies on (NumPy, PyTorch CPU,
# PyTorch CUDA) with the same value, and ask cuDNN for deterministic kernels.
# NOTE(review): `seed` is not defined in this cell; presumably it is provided by
# `from lib.config import *` -- confirm.
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True

import time
from models import queue_var

In [2]:


# Instantiate the CIFAR-10 training split once with download=True so the files
# exist under ../../data/ before they are re-opened with download=False later.
datasets.CIFAR10(train=True, download=True, root='../../data/')

# Human-readable class names, listed in label order 0..9.
classes = 'airplane car bird cat deer dog frog horse ship truck'.split()

# NOTE(review): `device` is not defined in this cell; presumably it comes from
# `from lib.config import *` -- confirm.
print('device:', device)

Files already downloaded and verified
device: cpu


In [3]:
def load_rotated_cifar(left_out_idx, num_domains=6, angle_step=15,
                       samples_per_class=400, num_classes=10,
                       images_per_class=5000, root='../../data/'):
    """Build rotated-CIFAR domains and split them leave-one-domain-out.

    Domain ``d`` is the CIFAR-10 training split with every image rotated by a fixed
    angle of ``360 - angle_step * d`` degrees. One random permutation per class is
    drawn up front and reused for every domain, so each domain picks the same
    underlying images.

    Args:
        left_out_idx: index of the domain to hold out. Every other domain
            contributes ``samples_per_class`` images per class; the held-out
            domain contributes all of its images.
        num_domains: number of rotation domains to generate.
        angle_step: rotation increment (degrees) between consecutive domains.
        samples_per_class: images kept per class in each non-held-out domain.
        num_classes: number of class labels (labels are assumed to be 0..num_classes-1).
        images_per_class: length of the per-class permutation; it has to match the
            number of images each class has in the training split.
        root: dataset root directory; the data must already be downloaded there.

    Returns:
        ``(train_x, train_y, test_x, test_y)`` where ``train_x`` / ``train_y`` are
        lists with one tensor per non-held-out domain, and ``test_x`` / ``test_y``
        are the tensors of the held-out domain. If ``left_out_idx`` does not match
        any generated domain, ``test_x`` and ``test_y`` are returned as empty lists.

    Note:
        Tensors are moved to the module-level ``device``.
    """
    # Drawn before any domain is processed so the NumPy RNG is consumed in a fixed
    # order (one permutation per class) regardless of the number of domains.
    class_perms = [np.random.permutation(images_per_class) for _ in range(num_classes)]

    train_x, train_y = [], []
    test_x, test_y = [], []

    for dom_idx in range(num_domains):
        angle = 360 - angle_step * dom_idx
        # degrees=(angle, angle) collapses the "random" rotation to a fixed angle.
        transform = transforms.Compose([
            transforms.RandomRotation(degrees=(angle, angle)),
            transforms.ToTensor(),
        ])
        cifar_train = datasets.CIFAR10(root=root, download=False, train=True, transform=transform)
        # A single batch spanning the whole split materialises every rotated image at once.
        loader = torch.utils.data.DataLoader(dataset=cifar_train,
                                             batch_size=len(cifar_train),
                                             shuffle=False)
        data, targets = next(iter(loader))

        is_left_out = dom_idx == left_out_idx
        dom_x, dom_y = [], []
        for cls in range(num_classes):
            mask = targets == cls
            cls_x = data[mask][class_perms[cls]]
            # Labels are constant within a class, so they need no permutation.
            cls_y = targets[mask]
            if not is_left_out:
                cls_x = cls_x[:samples_per_class]
                cls_y = cls_y[:samples_per_class]
            # Sub-sample first, then transfer: only the kept images go to the device.
            dom_x.append(cls_x.to(device))
            dom_y.append(cls_y.to(device))

        dom_x = torch.cat(dom_x)
        dom_y = torch.cat(dom_y)

        if is_left_out:
            test_x, test_y = dom_x, dom_y
        else:
            train_x.append(dom_x)
            train_y.append(dom_y)

    return train_x, train_y, test_x, test_y

In [4]:
# Leave-one-domain-out / zero-shot setting 1:
#   dom  -> index of the held-out domain (5, i.e. the 6th domain)
#   clas -> class ids treated as unseen (3 and 9)
dom = 5
clas = [3, 9]

t1 = time.time()
train_x, train_y, test_x, test_y = load_rotated_cifar(dom)

# What load_rotated_cifar returns here (N = total number of domains):
#   train_x : list of length N-1 (held-out domain excluded);
#             train_x[i] is torch.Size([4000, 3, 32, 32]) -> 4000 images per train domain
#   train_y : list of length N-1; train_y[i] is torch.Size([4000]), the labels of train_x[i]
#   test_x  : torch.Size([50000, 3, 32, 32]), every image of the held-out domain
#   test_y  : torch.Size([50000]), the labels of test_x
# i.e. the types are: list, list, torch.Tensor, torch.Tensor
print('time elapsed: ',time.time()-t1)

time elapsed:  76.57342267036438


In [5]:
# Remove every sample whose label is one of the unseen classes in `clas`, then
# stack the per-domain tensors into train_x: (num_domains, samples_per_domain, ...)
# and train_y: (num_domains, samples_per_domain).
#
# The domain count and the per-domain length are derived from the data instead of
# being hard-coded (previously 5 and "-800"), so the cell keeps working when `clas`
# or the number of train domains changes, and re-running it is harmless.
seen_x, seen_y = [], []
for dom_x, dom_y in zip(train_x, train_y):
    keep = torch.ones_like(dom_y, dtype=torch.bool)
    for k in clas:
        keep &= dom_y != k
    seen_x.append(dom_x[keep])
    seen_y.append(dom_y[keep])

# torch.stack requires every domain to keep the same number of samples, which is
# the same requirement the former cat + view had.
train_x = torch.stack(seen_x)
train_y = torch.stack(seen_y).long()

# Number of samples left in each train domain after the unseen classes are removed.
length_of_domain = train_y.size(1)

In [6]:
# Shape check, expected: torch.Size([5, 3200, 3, 32, 32]) torch.Size([5, 3200])
print(train_x.shape, train_y.shape)

torch.Size([5, 3200, 3, 32, 32]) torch.Size([5, 3200])


In [15]:
# Training cell: shuffles each train domain, relabels classes to 0..C-1, builds
# Domainbed-style minibatches, fills the per-(class, domain) embedding queues that
# SRCAGG reads from `queue_var.train_queues`, then runs one update per minibatch.
batch_size=50
print('------------------------')

model = SRCAGG()
num_train_domains = train_x.size(0)  # derived from the data instead of a hard-coded 5
for epoch in range(1):
    print('epoch:',epoch)

    # One permutation shared by all domains: every domain is shuffled the same way.
    perm = torch.from_numpy(np.random.permutation(length_of_domain))
    x_train = train_x[:, perm].to(device)  # (num_domains, length_of_domain, 3, 32, 32)
    y_train = train_y[:, perm].to(device)  # (num_domains, length_of_domain)

    ## class ids might be disconnected due to the ZS setting, so, rename them as 0,1,...C-1
    # return_inverse yields, for every label, its position in the sorted unique ids,
    # which is exactly the per-element `unique_cids.index(label)` lookup, vectorised.
    unique_vals, y_train = torch.unique(y_train, return_inverse=True)
    unique_cids = list(unique_vals)

    ### Native ZSDG style
#     for i in range(5): # for every domain, accumulate/aggregate
#         avg_cost = 0
#         for k in range(0,length_of_domain,batch_size): #sample mini-batches

#             left_x = x_train[i][k:k+batch_size,:,:,:]
#             labels = y_train[i][k:k+batch_size]

#             avg_cost+= model.train(left_x,labels,i)

#         avg_cost = avg_cost/(length_of_domain/batch_size)

#     print(avg_cost)

    ### Domainbed style
    # all_minibatches[b][d] == [images, labels] of domain d for minibatch b
    # (64 minibatches when length_of_domain=3200 and batch_size=50).
    all_minibatches = [
        [[x_train[d][k:k+batch_size], y_train[d][k:k+batch_size]]
         for d in range(num_train_domains)]
        for k in range(0, length_of_domain, batch_size)
    ]

    ################################ Code required for SRCAGG ################################
    print('Firstly, computing Queues for the algorithm ')
    queue_sz = queue_var.queue_sz # the memory module/ queue size
    minibatches_device=all_minibatches[0]
    num_classes=len(unique_cids)
    num_domains=len(minibatches_device)

    # Pre-populate the table of queues with one slot per (class, domain). Minibatches
    # may lack examples of some class, so the slots are created up front to keep a
    # fixed storage order. A local list is filled first and assigned to the global
    # `queue_var.train_queues` only at the end (writing to the global directly was slow).
    train_queues = [[None] * num_domains for _ in range(num_classes)]

    # flag_arr[c][d] == 1 once queue (c, d) holds exactly queue_sz embeddings
    flag_arr = np.zeros([num_classes, num_domains], dtype = int)

    #### The creation of the initial list of queues
    tpcnt=0
    grew_this_pass = False  # did any unfilled queue receive data during the current pass?
    while not np.sum(flag_arr)==num_classes*num_domains: #until all queues have queue_sz elts
        minibatches_device=all_minibatches[tpcnt]
        print('\n tpcnt: ',tpcnt,' ... Completed',np.sum(flag_arr),' queues out of ',num_classes*num_domains)
        for id_c in range(num_classes): # loop over classes
            for id_d in range(num_domains): # loop over domains
                if flag_arr[id_c][id_d]==1:
                    print('Queue (class ',id_c,', domain ',id_d,') is completely filled. ')
                    continue

                # indices of those egs from domain id_d, whose class label is id_c
                mb_ids=(minibatches_device[id_d][1] == id_c).nonzero(as_tuple=True)[0]
                if mb_ids.size(0)==0:
                    print('class has no element')
                    continue

                data_tensor=minibatches_device[id_d][0][mb_ids] # data
                # The embeddings are only stored, never back-propagated through, so
                # skip building the autograd graph altogether.
                with torch.no_grad():
                    data_emb = model.key_encoder(data_tensor) # extract features: torch.Size([negs, dim])
                data_emb = data_emb.detach()
                data_emb = torch.div(data_emb,torch.norm(data_emb,dim=1).reshape(-1,1))#l2 normalize

                current_queue=train_queues[id_c][id_d]
                if current_queue is None:
                    current_queue = data_emb
                elif current_queue.size(0) < queue_sz:
                    current_queue = torch.cat((current_queue, data_emb), 0)
                if current_queue.size(0) > queue_sz:
                    # keep only the last queue_sz entries
                    current_queue = current_queue[-queue_sz:]
                if current_queue.size(0) == queue_sz:
                    flag_arr[id_c][id_d]=1
                train_queues[id_c][id_d] = current_queue
                grew_this_pass = True
                print('Queue (class ',id_c,', domain ',id_d,') : ',train_queues[id_c][id_d].size())
        tpcnt+=1
        if tpcnt==len(all_minibatches):
            # A complete pass over the data without any queue growing means the
            # remaining (class, domain) pairs have no samples at all; cycling again
            # would loop forever, so fail loudly instead.
            if not grew_this_pass:
                raise RuntimeError('Queues cannot be filled: at least one (class, domain) '
                                   'pair has no samples in the training minibatches.')
            tpcnt=0
            grew_this_pass = False
    queue_var.train_queues=train_queues # assign to the global variable

    # sanity checking the queues obtained
    for id_c in range(num_classes):
        for id_d in range(num_domains):
            print('Queue (class ',id_c,', domain ',id_d,') : ',queue_var.train_queues[id_c][id_d].size())

    model.atten.train()
    model.g_att.train()

    # if args.algorithm=='RCERM' or args.algorithm=='ERMR':
    #     model.atten.train()
    #     model.g_att.train()
    # if args.algorithm=='RCERMNG':
    #     model.atten.train()

    ################################ Code required for SRCAGG ################################

    # One model update per minibatch; each update returns that minibatch's cost.
    mb_costs=[]
    for minibatches in all_minibatches:
        mb_cost=model.update(minibatches)
        mb_costs.append(mb_cost)
    avg_cost=sum(mb_costs)/len(mb_costs)
    print('Epoch ',epoch,' avg_cost: ',avg_cost)
    
    

------------------------
epoch: 0
Firstly, computing Queues for the algorithm 

 tpcnt:  0  ... Completed 0  queues out of  40
Queue (class  0 , domain  0 ) :  torch.Size([3, 300])
Queue (class  0 , domain  1 ) :  torch.Size([3, 300])
Queue (class  0 , domain  2 ) :  torch.Size([3, 300])
Queue (class  0 , domain  3 ) :  torch.Size([3, 300])
Queue (class  0 , domain  4 ) :  torch.Size([3, 300])
Queue (class  1 , domain  0 ) :  torch.Size([5, 300])
Queue (class  1 , domain  1 ) :  torch.Size([5, 300])
Queue (class  1 , domain  2 ) :  torch.Size([5, 300])
Queue (class  1 , domain  3 ) :  torch.Size([5, 300])
Queue (class  1 , domain  4 ) :  torch.Size([5, 300])
Queue (class  2 , domain  0 ) :  torch.Size([7, 300])
Queue (class  2 , domain  1 ) :  torch.Size([7, 300])
Queue (class  2 , domain  2 ) :  torch.Size([7, 300])
Queue (class  2 , domain  3 ) :  torch.Size([7, 300])
Queue (class  2 , domain  4 ) :  torch.Size([7, 300])
Queue (class  3 , domain  0 ) :  torch.Size([5, 300])
Queue (cl

In [16]:
# Header followed by one loss per line, in the order the minibatches were processed.
print('\n'.join(['Mini-batch losses:'] + [str(cost) for cost in mb_costs]))

Mini-batch losses:
9.407519340515137
3.2532966136932373
2.928440570831299
3.0497946739196777
2.9720096588134766
2.9710400104522705
2.962465286254883
2.877997875213623
2.934741258621216
2.834049940109253
2.8386850357055664
2.7761688232421875
2.8546862602233887
2.6452596187591553
2.7446041107177734
2.8998048305511475
2.810018301010132
2.6070337295532227
2.635939598083496
2.6736197471618652
2.6868464946746826
2.6608481407165527
2.6979854106903076
2.667520046234131
2.5456998348236084
2.872598648071289
2.6794185638427734
2.4835190773010254
2.6068661212921143
2.5312061309814453
2.6404917240142822
2.5849828720092773
2.6211063861846924
2.4989194869995117
2.5713486671447754
2.5774643421173096
2.4437623023986816
2.575110912322998
2.4591612815856934
2.4671895503997803
2.5472488403320312
2.5158355236053467
2.5241363048553467
2.5151278972625732
2.4306976795196533
2.4792332649230957
2.462926149368286
2.3605613708496094
2.6327853202819824
2.417515516281128
2.4441330432891846
2.386134147644043
2.45911