In [None]:
class DTC_trainer:
    """Controls the training and quick evaluation of the twin-branch DTC model."""

    def target_distribution(self, q):
        """Return the sharpened target distribution for soft assignments ``q``.

        Each entry is squared and divided by its column total (``q.sum(0)``),
        then every row is renormalised to sum to 1. The result is detached
        from the autograd graph so it acts as a fixed target for the loss.

        Args:
            q: 2-D tensor of soft assignments; presumably (batch, n_clusters)
               softmax output — TODO confirm against callers.
        Returns:
            Tensor with the same shape as ``q`` whose rows sum to 1 and that
            does not require grad.
        """
        weight = q ** 2 / q.sum(0)
        return (weight.t() / weight.sum(1)).t().detach()

    def pretrain(self, train_loader, test_loader, epochs, device=None):
        """Train a fresh DTC model, scoring one test batch after each epoch.

        Args:
            train_loader: iterable of ``(x, label)`` batches; labels are not
                used by the loss.
            test_loader: iterable of ``(x, label)`` batches; only the first
                batch is evaluated after each epoch.
            epochs: number of passes over ``train_loader``.
            device: torch device to train on. Defaults to CUDA when it is
                available, otherwise CPU.
        Returns:
            The trained ``DTC`` model.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        model = DTC().to(device)
        optimizer = optim.Adam(model.parameters())

        for epoch in range(epochs):
            model.train()
            running_loss = 0.0
            for i, (x, _label) in enumerate(train_loader):
                x = x.to(device)
                optimizer.zero_grad()
                q_a, q_b = model(x)
                # Branch A's sharpened output is the fixed target for branch B.
                target = self.target_distribution(q_a)
                # F.kl_div expects log-probabilities as its input. The model
                # emits softmax probabilities, so take a (clamped) log first.
                # 'batchmean' yields the true per-sample KL divergence.
                log_q_b = torch.log(q_b.clamp(min=1e-12))
                loss = F.kl_div(log_q_b, target, reduction="batchmean")
                loss.backward()
                optimizer.step()
                # .item() works on the 0-d loss tensor; indexing [0] does not.
                running_loss += loss.item()
                if i % 100 == 99:  # print every 100 mini-batches
                    print('[%d, %5d] loss: %.7f' %
                          (epoch + 1, i + 1, running_loss / 100))
                    running_loss = 0.0
            self._evaluate(model, test_loader, device)
        return model

    def _evaluate(self, model, test_loader, device):
        """Print predicted clusters plus ACC/NMI for the first test batch."""
        model.eval()
        with torch.no_grad():
            for x, label in test_loader:
                q_a, _ = model(x.to(device))
                y_pred = np.argmax(q_a.cpu().numpy(), axis=1)
                label = label.cpu().numpy()
                print('y_pred', y_pred)
                print(' ' * 8 + '|==>  acc: %.4f,  nmi: %.4f  <==|'
                      % (acc(label, y_pred), nmi(label, y_pred)))
                # Only the first batch is scored (a quick per-epoch peek).
                break
import random

import numpy as np
import torch

# NOTE(review): this cell needs DTC, acc and nmi, which are defined in the
# cells below — run those cells before this one.
# Seed every RNG training touches: Python's `random` alone does not affect
# torch weight initialisation or dropout masks. (cuDNN kernels may still be
# nondeterministic on GPU.)
SEED = 7
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
dtc = DTC_trainer()
dtc.pretrain(train_loader, test_loader, 20)

In [None]:
import numpy as np
from sklearn.metrics import normalized_mutual_info_score, adjusted_rand_score

# Short aliases for the scikit-learn clustering metrics used by the
# evaluation code: normalized mutual information and adjusted Rand index.
nmi, ari = normalized_mutual_info_score, adjusted_rand_score


def acc(y_true, y_pred):
    """
    Calculate clustering accuracy under the best one-to-one label mapping.

    Cluster ids in ``y_pred`` are arbitrary. Predicted clusters are therefore
    matched to true classes with the Hungarian algorithm, maximising the
    number of agreeing samples, before the hits are counted.

    # Arguments
        y_true: true labels, integer array-like with shape `(n_samples,)`
        y_pred: predicted labels, integer array-like with shape `(n_samples,)`
    # Return
        accuracy, in [0,1]
    # Raises
        ValueError: if there are no samples.
    """
    # scipy's Hungarian solver replaces sklearn.utils.linear_assignment_,
    # which was removed from scikit-learn in 0.23.
    from scipy.optimize import linear_sum_assignment

    y_true = np.asarray(y_true).astype(np.int64).ravel()
    y_pred = np.asarray(y_pred).astype(np.int64).ravel()
    assert y_pred.size == y_true.size
    if y_pred.size == 0:
        raise ValueError("acc() requires at least one sample")
    D = max(y_pred.max(), y_true.max()) + 1
    # w[i, j] = number of samples put in cluster i whose true label is j.
    w = np.zeros((D, D), dtype=np.int64)
    np.add.at(w, (y_pred, y_true), 1)
    # Minimising (max - w) is equivalent to maximising the matched counts.
    row_ind, col_ind = linear_sum_assignment(w.max() - w)
    return float(w[row_ind, col_ind].sum()) / y_pred.size

In [None]:
class BasicUnit(nn.Module):
    """Small CNN branch mapping a 1-channel image to a 68-way softmax.

    The flattened size 50*9*9 used in ``forward`` (via ``view``) only fits a
    single-channel 32x32 input, as traced in the shape comments below.
    """

    def __init__(self):
        super(BasicUnit, self).__init__()
        self.dropout = nn.Dropout(p=0.1)
        self.conv_a1 = nn.Conv2d(1, 50, 4, stride=2, padding=2)
        self.conv_a2 = nn.Conv2d(50, 50, 5, stride=2, padding=2)
        self.leReLU = nn.LeakyReLU()
        self.fca1 = nn.Linear(50 * 9 * 9, 68)
        # Normalise explicitly over the class dimension. nn.Softmax() without
        # dim is deprecated and only implicitly picks dim=1 for 2-D input.
        self.softmax = nn.Softmax(dim=1)

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                # In-place initialiser; plain xavier_uniform is deprecated.
                torch.nn.init.xavier_uniform_(m.weight)

    def forward(self, x):
        """Return class probabilities of shape (batch, 68)."""
        # N x 1 x 32 x 32
        x = self.dropout(x)
        x = self.conv_a1(x)
        # N x 50 x 17 x 17
        x = self.leReLU(x)
        x = self.dropout(x)
        x = self.conv_a2(x)
        # N x 50 x 9 x 9
        x = self.leReLU(x)
        x = self.dropout(x)
        x = x.view(-1, 50 * 9 * 9)
        # N x 4050
        x = self.fca1(x)
        # N x 68
        x = self.softmax(x)
        return x
    
class DTC(nn.Module):
    """Twin-branch network: two independently initialised BasicUnit towers.

    Both towers see the same input. ``forward`` returns their outputs as a
    pair ``(unit_a(x), unit_b(x))``.
    """

    def __init__(self):
        super(DTC, self).__init__()
        # Construction order is kept (a, then b) so weight init consumes the
        # RNG in the same sequence.
        self.unit_a = BasicUnit()
        self.unit_b = BasicUnit()

    def forward(self, x):
        # Tuple elements evaluate left to right: unit_a runs before unit_b.
        return self.unit_a(x), self.unit_b(x)
