In [1]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim

In [20]:
class TF(nn.Module):
    def __init__(self, input_size, R):
        super(TF, self).__init__()
        self.I = input_size[0]
        self.J = input_size[1]
        self.K = input_size[2]
        self.U = nn.Parameter(torch.Tensor(R, self.I).uniform_(0,1), requires_grad=True)
        self.V = nn.Parameter(torch.Tensor(R, self.J).uniform_(0,1), requires_grad=True)
        self.W = nn.Parameter(torch.Tensor(R, self.K).uniform_(0,1), requires_grad=True)
        
    def non_negative(self, R=None):
#         torch.clamp(self.U, min=0)
#         torch.clamp(self.V, min=0)
#         torch.clamp(self.W, min=0)
        torch.clamp(self.U, min=0)
        torch.clamp(self.V, min=0)
        torch.clamp(self.W, min=0)

    def forward_one_rank(self, u, v, w):
        UV = torch.ger(u, v)
        UV2 = UV.unsqueeze(2).repeat(1,1,self.K)
        W2 = w.unsqueeze(0).unsqueeze(1).repeat(self.I, self.J, 1)
        outputs = UV2 * W2
        return outputs
    
    def forward(self, X=None):
        if X is not None:
            weight = (X != 0).float()
        else:
            weight = Variable(torch.ones(self.I, self.J, self.K))
        
        output = self.forward_one_rank(self.U[0], self.V[0], self.W[0])
        
        for i in np.arange(1, R):
            one_rank = self.forward_one_rank(self.U[i], self.V[i], self.W[i])
            output = output + one_rank
        
        return output

In [19]:
input_size = (11, 10, 9)
X = torch.Tensor(11, 10, 9).uniform_(0, 1)

R = 2
model = TF(input_size, R)
criterion = torch.nn.MSELoss(size_average=False)
optimizer = optim.Adam(model.parameters(), lr=0.001)

X = Variable(X)

epoch = 1000
for i in range(epoch):
    output = model.forward()
    loss = criterion(output, X)
    
    optimizer.zero_grad()
    
    loss.backward()
    optimizer.step()
    
    model.non_negative()
    
    if i % 100 == 0:
        print('{}: {}'.format(i, loss.data[0]))
        
print('{}: {}'.format(epoch, loss.data[0]))

0: 186.80499267578125
100: 139.28732299804688
200: 108.21112823486328
300: 91.53741455078125
400: 84.03804779052734
500: 81.01953125
600: 79.81192016601562
700: 79.2923583984375
800: 79.03129577636719
900: 78.8693618774414
1000: 78.75040435791016


In [15]:
model.non_negative()
print(model.U.shape, model.V.shape, model.W.shape)

torch.Size([2, 11]) torch.Size([2, 10]) torch.Size([2, 9])
