# ライブラリのインポート

In [3]:
import numpy as np
import utils as ut
import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision.models as models
import multiprocessing

# Discriminator

In [4]:
class Discriminator(nn.Module):
    
    def __init__(self, item_num, user_num, emb_dim,
                           lamda, param=None, initdelta=0.05):
        super(Discriminator, self).__init__()
        self.item_num = item_num
        self.user_num = user_num
        self.emb_dim = emb_dim
        self.lamda = lamda
        self.param = param
        self.initdelta = initdelta
        self.d_params = []
        
        if self.param is None:
            self.user_embeddings = nn.Embedding(
                self.user_num, self.emb_dim)            
            self.item_embeddings = nn.Embedding(
                self.item_num, self.emb_dim)            
            self.item_bias = nn.Embedding(
                self.item_num, 1)
            
            # initialize
            torch.nn.init.uniform_(
                self.user_embeddings.weight, a=-initdelta, b=initdelta)
            torch.nn.init.uniform_(
                self.item_embeddings.weight, a=-initdelta, b=initdelta)
#             torch.nn.init.uniform_(
#                 self.item_bias.weight, a=initdelta, b=-initdelta)
            
    def forward(self, user, item, label):
        pre_logits = (self.user_embeddings(user) \
                              * self.item_embeddings(item)).squeeze().sum(1).view(-1,1) \
                                + self.item_bias(item).view(-1, 1)
        pre_loss =  F.binary_cross_entropy_with_logits(pre_logits, label)
        return pre_loss

    def all_rating(self, user):
        rating = torch.mm(self.user_embeddings(user).view(-1, 5),
                                               self.item_embeddings.weight.t()) + self.item_bias
        return all_rating

    def all_logits(self, user):
        logits = (self.user_embeddings(user) \
                              * self.item_embeddings.weight).sum(1) + self.item_bias
        return all_logits

    def get_reward(self, user, item):
        reward_logits = (self.user_embeddings(user) * self.item_embeddings(item)).squeeze().sum(1).view(-1,1) + self.item_bias(item).view(-1, 1)
        reward = 2 * (torch.sigmoid(reward_logits) - 0.5)
        return reward

# Generator

In [5]:
class Generator(nn.Module):
    
    def __init__(self, item_num, user_num, emb_dim,
                           lamda, param=None, initdelta=0.05, lr=0.05):
        super(Generator, self).__init__()
        self.item_num = item_num
        self.user_num = user_num
        self.emb_dim = emb_dim
        self.lamda = lamda
        self.param = param
        self.initdelta = initdelta
        self.lr = lr
        self.g_params = []
        
        import pickle
        with open("ml-100k/model_dns_ori.pkl", "rb") as f:
            param = pickle.load(f, encoding='latin1')
    
        if self.param is None:
            self.user_embeddings = nn.Embedding(
                self.user_num, self.emb_dim)
            self.item_embeddings = nn.Embedding(
                self.item_num, self.emb_dim)
            self.item_bias = torch.zeros(self.item_num)
            
        self.user_embeddings.weight = torch.nn.Parameter(torch.tensor(param[0]))
        self.item_embeddings.weight = torch.nn.Parameter(torch.tensor(param[1]))
        
            # initialize
#             torch.nn.init.uniform_(
#                 self.user_embeddings.weight, a=-initdelta, b=initdelta)
#             torch.nn.init.uniform_(
#                 self.item_embeddings.weight, a=-initdelta, b=initdelta)
        
    def forward(self, user, item, reward):
        softmax_score = F.softmax(self.all_logits(user).view(1, -1), -1)
        i_prob = torch.gather(softmax_score.view(-1), 0, item).clamp(min=1e-8)
        loss = - torch.mean(torch.log(i_prob) * reward) # \
#                    + self.lamda * (F.normalize(self.user_embeddings(user), p=2, dim=1) \
#                                              + F.normalize(self.item_embeddings(item), p=2, dim=1) \
#                                              + F.normalize(self.item_bias(item), p=2, dim=1))
        return loss

    def all_rating(self, user):
        rating = torch.mm(self.user_embeddings(user).view(-1, 5),
                                               self.item_embeddings.weight.t()) + self.item_bias
        return rating
    
    def all_logits(self, user):
        logits = (self.user_embeddings(user) \
                              * self.item_embeddings.weight).sum(1) + self.item_bias
        return logits

# 評価関数

In [6]:
def dcg_at_k(r, k):
    r = np.asfarray(r)[:k]
    return np.sum(r / np.log2(np.arange(2, r.size + 2)))


def ndcg_at_k(r, k):
    dcg_max = dcg_at_k(sorted(r, reverse=True), k)
    if not dcg_max:
        return 0.
    return dcg_at_k(r, k) / dcg_max

def simple_test_one_user(x):
    
    # import pdb; pdb.set_trace()
    rating = x[0]
    u = x[1]

    test_items = list(all_items - set(user_pos_train[u]))
    item_score = []
    for i in test_items:
        item_score.append((i, rating[i]))

    item_score = sorted(item_score, key=lambda x: x[1])
    item_score.reverse()
    item_sort = [x[0] for x in item_score]

    r = []
    for i in item_sort:
        if i in user_pos_test[u]:
            r.append(1)
        else:
            r.append(0)

    p_3 = np.mean(r[:3])
    p_5 = np.mean(r[:5])
    p_10 = np.mean(r[:10])
    
    ndcg_3 = ndcg_at_k(r, 3)
    ndcg_5 = ndcg_at_k(r, 5)
    ndcg_10 = ndcg_at_k(r, 10)

    return np.array([p_3, p_5, p_10, ndcg_3, ndcg_5, ndcg_10])

def simple_test(model):
    result = np.array([0.] * 6)
    pool = multiprocessing.Pool(cores)
    batch_size = 128
    test_users = list(user_pos_test.keys())
    test_user_num = len(test_users)
    index = 0
    
    while True:
        if index >= test_user_num:
            break
        user_batch = test_users[index:index + batch_size]
        index += batch_size
    
#         print(user_batch)
        user_batch_rating = model.all_rating(torch.tensor(user_batch))

        user_batch_rating = user_batch_rating.detach_().cpu().numpy()

        user_batch_rating_uid = zip(user_batch_rating, user_batch)
        batch_result = pool.map(simple_test_one_user, user_batch_rating_uid)
        for re in batch_result:
            result += re

    pool.close()
    ret = result / test_user_num
    ret = list(ret)
    return ret

# ハイパーパラメーター

In [7]:
EMB_DIM = 5
USER_NUM = 943
ITEM_NUM = 1683
BATCH_SIZE = 16
INIT_DELTA = 0.05

all_items = set(range(ITEM_NUM))
workdir = "ml-100k/"
DIS_TRAIN_FILE = workdir + "dis-train.txt"
cores = multiprocessing.cpu_count()

# ロードデータ

In [8]:
# positiveな要素だけを引っ張ってくる
user_pos_train = {}
with open(workdir + 'movielens-100k-train.txt')as fin:
    for line in fin:
        line = line.split()
        uid = int(line[0])
        iid = int(line[1])
        r = float(line[2])
        if r > 3.99:
            if uid in user_pos_train:
                user_pos_train[uid].append(iid)
            else:
                user_pos_train[uid] = [iid]

# testでnegativeな要素だけを引っ張ってくる                
user_pos_test = {}
with open(workdir + 'movielens-100k-test.txt')as fin:
    for line in fin:
        line = line.split()
        uid = int(line[0])
        iid = int(line[1])
        r = float(line[2])
        if r > 3.99:
            if uid in user_pos_test:
                user_pos_test[uid].append(iid)
            else:
                user_pos_test[uid] = [iid]

all_users = user_pos_train.keys()

# Generatorによるデータセレクト

In [9]:
def generate_for_d(model, filename):
    data = []
    for user in user_pos_train:
        pos = user_pos_train[user]

        rating = model.all_rating(torch.tensor(user))
        rating = rating.detach_().cpu().numpy()
        rating = np.array(rating) / 0.2  # Temperature
        exp_rating = np.exp(rating)
        prob = exp_rating / np.sum(exp_rating)

        neg = np.random.choice(np.arange(ITEM_NUM), size=len(pos), p=prob.reshape(-1,))
        for i in range(len(pos)):
            data.append(str(user) + '\t' + str(pos[i]) + '\t' + str(neg[i]))

    with open(filename, 'w')as fout:
        fout.write('\n'.join(data))

In [10]:
generator = Generator(
    ITEM_NUM, USER_NUM,EMB_DIM, lamda=0.0 / BATCH_SIZE,
    param=None, initdelta=INIT_DELTA)

discriminator = Discriminator(
    ITEM_NUM, USER_NUM,EMB_DIM, lamda=0.0 / BATCH_SIZE,
    param=None, initdelta=INIT_DELTA)

g_optimizer = torch.optim.SGD(
    generator.parameters(), lr=0.001, momentum=0.9)

d_optimizer = torch.optim.SGD(
    discriminator.parameters(), lr=0.001, momentum=0.9)                    

In [11]:
from torch import nn
from torch import autograd

generator.train()
discriminator.train()
for epoch in range(15):
    if epoch >= 0:
        for d_epoch in range(100):
            if d_epoch % 5 == 0:
                generate_for_d(generator, DIS_TRAIN_FILE)
                train_size = ut.file_len(DIS_TRAIN_FILE)
            index = 1
            while True:
                if index > train_size:
                    break
                if index + BATCH_SIZE <= train_size + 1:
                    users, items, labels = ut.get_batch_data(
                        DIS_TRAIN_FILE, index, BATCH_SIZE)
                else:
                    users, items, labels = ut.get_batch_data(
                        DIS_TRAIN_FILE, index, train_size - index + 1)
                    
                index += BATCH_SIZE
                users = torch.tensor(users).view(-1, 1)
                items = torch.tensor(items).view(-1, 1)
                labels = torch.tensor(labels).view(-1, 1)
    
                d_loss = discriminator(users, items, labels)
                d_optimizer.zero_grad()
                d_loss.backward()
                d_optimizer.step()
            print("\r[D Epoch %d/%d] [loss: %f]" %(d_epoch, 100, d_loss.item()))

        for g_epoch in range(50):
            for user in user_pos_train:
                sample_lambda = 0.2
                pos = user_pos_train[user]

                rating = generator.all_logits(torch.tensor(user))
                rating = rating.detach_().cpu().numpy()
                exp_rating = np.exp(rating)
                prob = exp_rating / np.sum(exp_rating)
                
                pn = (1 - sample_lambda) * prob
                pn[pos] += sample_lambda * 1.0 / len(pos)
                
                sample = np.random.choice(range(ITEM_NUM), 2*len(pos), p=pn)
                reward = discriminator.get_reward(torch.tensor([user] * 2 * len(pos)), torch.tensor(sample))
                reward = reward.detach_().cpu().numpy()
                reward = reward * prob[sample] / pn[sample]

                g_loss = generator(torch.tensor(user), torch.tensor(sample), torch.tensor(reward))
                g_optimizer.zero_grad()
                g_loss.backward()
                g_optimizer.step()
                
            print("\r[G Epoch %d/%d] [loss: %f]" %(g_epoch, 50, g_loss.item()))
            result = simple_test(generator)
            print("epoch ", epoch, "gen: ", result)
#             buf = '\t'.join([str(x) for x in result])
#             gen_log.write(str(epoch) + '\t' + buf + '\n')
#             gen_log.flush()

[D Epoch 0/100] [loss: 0.861321]
[D Epoch 1/100] [loss: 0.860576]
[D Epoch 2/100] [loss: 0.859837]
[D Epoch 3/100] [loss: 0.859101]
[D Epoch 4/100] [loss: 0.858370]
[D Epoch 5/100] [loss: 0.857644]
[D Epoch 6/100] [loss: 0.856921]
[D Epoch 7/100] [loss: 0.856204]
[D Epoch 8/100] [loss: 0.855491]
[D Epoch 9/100] [loss: 0.854781]
[D Epoch 10/100] [loss: 0.854076]
[D Epoch 11/100] [loss: 0.853376]
[D Epoch 12/100] [loss: 0.852679]
[D Epoch 13/100] [loss: 0.851986]
[D Epoch 14/100] [loss: 0.851298]
[D Epoch 15/100] [loss: 0.850614]
[D Epoch 16/100] [loss: 0.849934]
[D Epoch 17/100] [loss: 0.849258]
[D Epoch 18/100] [loss: 0.848586]
[D Epoch 19/100] [loss: 0.847917]
[D Epoch 20/100] [loss: 0.847254]
[D Epoch 21/100] [loss: 0.846593]
[D Epoch 22/100] [loss: 0.845937]
[D Epoch 23/100] [loss: 0.845286]
[D Epoch 24/100] [loss: 0.844638]
[D Epoch 25/100] [loss: 0.843994]
[D Epoch 26/100] [loss: 0.843354]
[D Epoch 27/100] [loss: 0.842718]
[D Epoch 28/100] [loss: 0.842085]
[D Epoch 29/100] [loss: 

[G Epoch 28/50] [loss: 0.280824]
epoch  0 gen:  [0.33187134502924009, 0.30614035087719282, 0.27149122807017512, 0.3434592720003079, 0.32644494241825239, 0.31290886054927891]
[G Epoch 29/50] [loss: 0.361999]
epoch  0 gen:  [0.33260233918128684, 0.30570175438596475, 0.27149122807017512, 0.34410857443663656, 0.3262735598854653, 0.31300375876306474]
[G Epoch 30/50] [loss: 0.456485]
epoch  0 gen:  [0.33260233918128684, 0.30570175438596475, 0.27105263157894699, 0.34410857443663656, 0.3262735598854653, 0.31273831198759394]
[G Epoch 31/50] [loss: 0.295187]
epoch  0 gen:  [0.33260233918128684, 0.30614035087719282, 0.27149122807017512, 0.34397383200691184, 0.32646390838418166, 0.31294053501000313]
[G Epoch 32/50] [loss: -0.055654]
epoch  0 gen:  [0.33260233918128684, 0.30570175438596475, 0.27127192982456105, 0.34397383200691178, 0.32617617795104048, 0.31278622440284981]
[G Epoch 33/50] [loss: 0.416605]
epoch  0 gen:  [0.33260233918128684, 0.30570175438596475, 0.27127192982456105, 0.3439738320069

[G Epoch 6/50] [loss: -0.422090]
epoch  1 gen:  [0.33187134502924009, 0.30438596491228059, 0.2710526315789471, 0.3434592720003079, 0.32532661558114928, 0.31264706868523251]
[G Epoch 7/50] [loss: -0.058768]
epoch  1 gen:  [0.33260233918128684, 0.30438596491228059, 0.2710526315789471, 0.34410857443663651, 0.32547555837696462, 0.31277791009066497]
[G Epoch 8/50] [loss: 0.331407]
epoch  1 gen:  [0.33260233918128684, 0.30482456140350866, 0.2710526315789471, 0.34410857443663651, 0.32573069391464449, 0.31276217511505983]
[G Epoch 9/50] [loss: 0.080864]
epoch  1 gen:  [0.33260233918128684, 0.30526315789473668, 0.2710526315789471, 0.34410857443663651, 0.32601842434778555, 0.31275559992813556]
[G Epoch 10/50] [loss: -0.410067]
epoch  1 gen:  [0.33260233918128684, 0.30526315789473668, 0.2710526315789471, 0.34410857443663651, 0.32601842434778555, 0.31275559992813556]
[G Epoch 11/50] [loss: 0.182193]
epoch  1 gen:  [0.33260233918128684, 0.30526315789473668, 0.2710526315789471, 0.34410857443663651, 

[D Epoch 15/100] [loss: 0.769051]
[D Epoch 16/100] [loss: 0.768817]
[D Epoch 17/100] [loss: 0.768585]
[D Epoch 18/100] [loss: 0.768354]
[D Epoch 19/100] [loss: 0.768124]
[D Epoch 20/100] [loss: 0.767895]
[D Epoch 21/100] [loss: 0.767667]
[D Epoch 22/100] [loss: 0.767441]
[D Epoch 23/100] [loss: 0.767215]
[D Epoch 24/100] [loss: 0.766990]
[D Epoch 25/100] [loss: 0.766767]
[D Epoch 26/100] [loss: 0.766544]
[D Epoch 27/100] [loss: 0.766323]
[D Epoch 28/100] [loss: 0.766102]
[D Epoch 29/100] [loss: 0.765883]
[D Epoch 30/100] [loss: 0.765664]
[D Epoch 31/100] [loss: 0.765447]
[D Epoch 32/100] [loss: 0.765230]
[D Epoch 33/100] [loss: 0.765015]
[D Epoch 34/100] [loss: 0.764800]
[D Epoch 35/100] [loss: 0.764587]
[D Epoch 36/100] [loss: 0.764374]
[D Epoch 37/100] [loss: 0.764163]
[D Epoch 38/100] [loss: 0.763952]
[D Epoch 39/100] [loss: 0.763742]
[D Epoch 40/100] [loss: 0.763534]
[D Epoch 41/100] [loss: 0.763326]
[D Epoch 42/100] [loss: 0.763119]
[D Epoch 43/100] [loss: 0.762913]
[D Epoch 44/10

[G Epoch 31/50] [loss: -0.424095]
epoch  2 gen:  [0.3340643274853804, 0.30614035087719282, 0.27192982456140319, 0.34540717930929388, 0.32693907228371899, 0.31425776413439721]
[G Epoch 32/50] [loss: -0.052214]
epoch  2 gen:  [0.3340643274853804, 0.30614035087719282, 0.27192982456140319, 0.34540717930929388, 0.32693907228371899, 0.31427585280694725]
[G Epoch 33/50] [loss: -0.341111]
epoch  2 gen:  [0.3340643274853804, 0.30614035087719282, 0.27192982456140319, 0.34540717930929388, 0.32693907228371899, 0.31427585280694725]
[G Epoch 34/50] [loss: -0.248846]
epoch  2 gen:  [0.3340643274853804, 0.30614035087719282, 0.27192982456140319, 0.34540717930929388, 0.32693907228371899, 0.31428281990888235]
[G Epoch 35/50] [loss: -0.186603]
epoch  2 gen:  [0.3340643274853804, 0.30570175438596481, 0.27192982456140319, 0.34540717930929388, 0.32665134185057793, 0.31425104925686836]
[G Epoch 36/50] [loss: 0.009240]
epoch  2 gen:  [0.33333333333333365, 0.30614035087719282, 0.27192982456140319, 0.34489261930

[G Epoch 9/50] [loss: -0.281532]
epoch  3 gen:  [0.33333333333333365, 0.30614035087719282, 0.27171052631578912, 0.34451280172581084, 0.32658041227129891, 0.31383050312631566]
[G Epoch 10/50] [loss: -0.902316]
epoch  3 gen:  [0.33333333333333365, 0.30657894736842095, 0.27149122807017512, 0.34451280172581084, 0.32686814270444009, 0.3137286438580521]
[G Epoch 11/50] [loss: -0.175678]
epoch  3 gen:  [0.3340643274853804, 0.30614035087719282, 0.27149122807017512, 0.34502736173241472, 0.32663197313268955, 0.31375904919993264]
[G Epoch 12/50] [loss: -0.221176]
epoch  3 gen:  [0.3340643274853804, 0.30657894736842095, 0.27171052631578912, 0.34489261930269, 0.3268549165268671, 0.31382582110787466]
[G Epoch 13/50] [loss: -0.669141]
epoch  3 gen:  [0.3340643274853804, 0.30657894736842095, 0.27171052631578912, 0.34489261930269, 0.32682232163140579, 0.3137890778066183]
[G Epoch 14/50] [loss: -1.009671]
epoch  3 gen:  [0.3340643274853804, 0.30657894736842095, 0.27171052631578912, 0.34489261930269, 0.3

[D Epoch 31/100] [loss: 0.736101]
[D Epoch 32/100] [loss: 0.736003]
[D Epoch 33/100] [loss: 0.735905]
[D Epoch 34/100] [loss: 0.735807]
[D Epoch 35/100] [loss: 0.735709]
[D Epoch 36/100] [loss: 0.735612]
[D Epoch 37/100] [loss: 0.735515]
[D Epoch 38/100] [loss: 0.735419]
[D Epoch 39/100] [loss: 0.735322]
[D Epoch 40/100] [loss: 0.735226]
[D Epoch 41/100] [loss: 0.735131]
[D Epoch 42/100] [loss: 0.735035]
[D Epoch 43/100] [loss: 0.734940]
[D Epoch 44/100] [loss: 0.734846]
[D Epoch 45/100] [loss: 0.734751]
[D Epoch 46/100] [loss: 0.734657]
[D Epoch 47/100] [loss: 0.734563]
[D Epoch 48/100] [loss: 0.734469]
[D Epoch 49/100] [loss: 0.734376]
[D Epoch 50/100] [loss: 0.734283]
[D Epoch 51/100] [loss: 0.734190]
[D Epoch 52/100] [loss: 0.734098]
[D Epoch 53/100] [loss: 0.734005]
[D Epoch 54/100] [loss: 0.733913]
[D Epoch 55/100] [loss: 0.733822]
[D Epoch 56/100] [loss: 0.733730]
[D Epoch 57/100] [loss: 0.733639]
[D Epoch 58/100] [loss: 0.733548]
[D Epoch 59/100] [loss: 0.733458]
[D Epoch 60/10

[G Epoch 34/50] [loss: -0.622389]
epoch  4 gen:  [0.33406432748538045, 0.30614035087719293, 0.27127192982456111, 0.34592173931589776, 0.32727836357825074, 0.31389145417553554]
[G Epoch 35/50] [loss: -0.415196]
epoch  4 gen:  [0.33406432748538045, 0.30614035087719293, 0.27127192982456111, 0.34578699688617293, 0.32713368016580546, 0.31377095598094173]
[G Epoch 36/50] [loss: -0.732418]
epoch  4 gen:  [0.33406432748538045, 0.30614035087719293, 0.27127192982456111, 0.34578699688617293, 0.32713368016580546, 0.31379449592164049]
[G Epoch 37/50] [loss: -0.477341]
epoch  4 gen:  [0.33406432748538045, 0.30657894736842095, 0.27149122807017512, 0.34578699688617293, 0.32742141059894653, 0.31394303145147917]
[G Epoch 38/50] [loss: -0.415093]
epoch  4 gen:  [0.33406432748538045, 0.30614035087719293, 0.27149122807017512, 0.34592173931589776, 0.32724576868278943, 0.31405562575858897]
[G Epoch 39/50] [loss: -0.720103]
epoch  4 gen:  [0.33406432748538045, 0.30657894736842095, 0.27149122807017512, 0.34592

[G Epoch 12/50] [loss: -0.981374]
epoch  5 gen:  [0.33406432748538045, 0.30614035087719282, 0.27127192982456105, 0.34578699688617304, 0.32727400073037127, 0.31387034351544341]
[G Epoch 13/50] [loss: -0.815192]
epoch  5 gen:  [0.33479532163742726, 0.30614035087719282, 0.27127192982456105, 0.34630155689277686, 0.32732556159176185, 0.31391963992073862]
[G Epoch 14/50] [loss: -1.208696]
epoch  5 gen:  [0.33479532163742726, 0.30614035087719282, 0.27127192982456105, 0.34630155689277686, 0.32732556159176185, 0.3139138648434241]
[G Epoch 15/50] [loss: -0.824445]
epoch  5 gen:  [0.33479532163742726, 0.30614035087719282, 0.27127192982456105, 0.34630155689277686, 0.32726037180083922, 0.31387321837023174]
[G Epoch 16/50] [loss: -0.987069]
epoch  5 gen:  [0.33479532163742726, 0.30614035087719282, 0.27127192982456105, 0.34630155689277686, 0.32726037180083922, 0.31387321837023174]
[G Epoch 17/50] [loss: -0.369346]
epoch  5 gen:  [0.33479532163742726, 0.30570175438596475, 0.27127192982456105, 0.346301

[D Epoch 46/100] [loss: 0.719519]
[D Epoch 47/100] [loss: 0.719453]
[D Epoch 48/100] [loss: 0.719387]
[D Epoch 49/100] [loss: 0.719321]
[D Epoch 50/100] [loss: 0.719254]
[D Epoch 51/100] [loss: 0.719188]
[D Epoch 52/100] [loss: 0.719122]
[D Epoch 53/100] [loss: 0.719056]
[D Epoch 54/100] [loss: 0.718989]
[D Epoch 55/100] [loss: 0.718923]
[D Epoch 56/100] [loss: 0.718857]
[D Epoch 57/100] [loss: 0.718790]
[D Epoch 58/100] [loss: 0.718724]
[D Epoch 59/100] [loss: 0.718658]
[D Epoch 60/100] [loss: 0.718591]
[D Epoch 61/100] [loss: 0.718525]
[D Epoch 62/100] [loss: 0.718459]
[D Epoch 63/100] [loss: 0.718393]
[D Epoch 64/100] [loss: 0.718326]
[D Epoch 65/100] [loss: 0.718260]
[D Epoch 66/100] [loss: 0.718194]
[D Epoch 67/100] [loss: 0.718127]
[D Epoch 68/100] [loss: 0.718061]
[D Epoch 69/100] [loss: 0.717995]
[D Epoch 70/100] [loss: 0.717928]
[D Epoch 71/100] [loss: 0.717862]
[D Epoch 72/100] [loss: 0.717795]
[D Epoch 73/100] [loss: 0.717729]
[D Epoch 74/100] [loss: 0.717662]
[D Epoch 75/10

[G Epoch 37/50] [loss: -1.350487]
epoch  6 gen:  [0.33406432748538045, 0.30701754385964902, 0.27061403508771903, 0.34605648174562242, 0.32797903567458048, 0.31366969770917247]
[G Epoch 38/50] [loss: -0.944483]
epoch  6 gen:  [0.33406432748538045, 0.30701754385964902, 0.27061403508771903, 0.34605648174562242, 0.32797903567458048, 0.31366969770917247]
[G Epoch 39/50] [loss: -0.730021]
epoch  6 gen:  [0.33406432748538045, 0.30701754385964902, 0.27061403508771903, 0.34605648174562242, 0.32794644077911916, 0.3136502030928397]
[G Epoch 40/50] [loss: -1.209208]
epoch  6 gen:  [0.33406432748538045, 0.30701754385964902, 0.27061403508771903, 0.34592173931589765, 0.32784905884469434, 0.31358820090534506]
[G Epoch 41/50] [loss: -0.433412]
epoch  6 gen:  [0.33479532163742726, 0.30701754385964902, 0.27061403508771903, 0.34643629932250158, 0.32786802481062355, 0.31359765925434596]
[G Epoch 42/50] [loss: -1.072343]
epoch  6 gen:  [0.33479532163742726, 0.30701754385964902, 0.27061403508771903, 0.346436

[G Epoch 15/50] [loss: -1.083907]
epoch  7 gen:  [0.33479532163742726, 0.3065789473684209, 0.2710526315789471, 0.34681611689938074, 0.3279199884239733, 0.31426874736931526]
[G Epoch 16/50] [loss: -1.010090]
epoch  7 gen:  [0.33479532163742726, 0.3065789473684209, 0.2710526315789471, 0.34708560175883024, 0.32815211278812284, 0.31447649153714191]
[G Epoch 17/50] [loss: -1.087270]
epoch  7 gen:  [0.33479532163742726, 0.3065789473684209, 0.27127192982456116, 0.34695085932910547, 0.32805473085369802, 0.31455308570163182]
[G Epoch 18/50] [loss: -0.925459]
epoch  7 gen:  [0.33552631578947406, 0.3065789473684209, 0.27127192982456116, 0.34746541933570935, 0.32813888661054991, 0.31461848213185661]
[G Epoch 19/50] [loss: -1.152168]
epoch  7 gen:  [0.33552631578947406, 0.3065789473684209, 0.27127192982456116, 0.34746541933570935, 0.32813888661054991, 0.31460769702949887]
[G Epoch 20/50] [loss: -1.051931]
epoch  7 gen:  [0.33479532163742726, 0.30614035087719282, 0.27127192982456116, 0.3469508593291

[D Epoch 61/100] [loss: 0.703505]
[D Epoch 62/100] [loss: 0.703411]
[D Epoch 63/100] [loss: 0.703318]
[D Epoch 64/100] [loss: 0.703224]
[D Epoch 65/100] [loss: 0.703129]
[D Epoch 66/100] [loss: 0.703035]
[D Epoch 67/100] [loss: 0.702940]
[D Epoch 68/100] [loss: 0.702845]
[D Epoch 69/100] [loss: 0.702749]
[D Epoch 70/100] [loss: 0.702653]
[D Epoch 71/100] [loss: 0.702557]
[D Epoch 72/100] [loss: 0.702461]
[D Epoch 73/100] [loss: 0.702364]
[D Epoch 74/100] [loss: 0.702267]
[D Epoch 75/100] [loss: 0.702170]
[D Epoch 76/100] [loss: 0.702072]
[D Epoch 77/100] [loss: 0.701974]
[D Epoch 78/100] [loss: 0.701876]
[D Epoch 79/100] [loss: 0.701777]
[D Epoch 80/100] [loss: 0.701679]
[D Epoch 81/100] [loss: 0.701579]
[D Epoch 82/100] [loss: 0.701480]
[D Epoch 83/100] [loss: 0.701380]
[D Epoch 84/100] [loss: 0.701279]
[D Epoch 85/100] [loss: 0.701179]
[D Epoch 86/100] [loss: 0.701078]
[D Epoch 87/100] [loss: 0.700977]
[D Epoch 88/100] [loss: 0.700875]
[D Epoch 89/100] [loss: 0.700773]
[D Epoch 90/10

epoch  8 gen:  [0.33479532163742726, 0.30482456140350866, 0.27127192982456111, 0.34771660380594738, 0.32704414260386322, 0.31529597376060386]
[G Epoch 40/50] [loss: -1.080063]
epoch  8 gen:  [0.33479532163742726, 0.30438596491228059, 0.27127192982456111, 0.34785134623567215, 0.32682119920968566, 0.31533278356783168]
[G Epoch 41/50] [loss: -1.360160]
epoch  8 gen:  [0.33479532163742726, 0.30438596491228059, 0.27127192982456111, 0.34785134623567215, 0.32682119920968566, 0.31532446424871702]
[G Epoch 42/50] [loss: -1.420154]
epoch  8 gen:  [0.33479532163742726, 0.30438596491228059, 0.2710526315789471, 0.34785134623567215, 0.32682119920968566, 0.31516938546183104]
[G Epoch 43/50] [loss: -0.732671]
epoch  8 gen:  [0.33552631578947406, 0.30438596491228059, 0.2710526315789471, 0.34823116381255131, 0.32677537813665136, 0.31514857263888368]
[G Epoch 44/50] [loss: -1.774295]
epoch  8 gen:  [0.33552631578947406, 0.30394736842105258, 0.2710526315789471, 0.34823116381255131, 0.3264876477035103, 0.3

[G Epoch 17/50] [loss: -1.446187]
epoch  9 gen:  [0.33552631578947406, 0.30438596491228059, 0.27127192982456116, 0.34836590624227598, 0.3269053549665375, 0.31542061022942186]
[G Epoch 18/50] [loss: -1.523693]
epoch  9 gen:  [0.33552631578947406, 0.30394736842105258, 0.2710526315789471, 0.34823116381255131, 0.32652024259897156, 0.31520110534857243]
[G Epoch 19/50] [loss: -1.246497]
epoch  9 gen:  [0.33552631578947406, 0.30394736842105258, 0.27127192982456111, 0.34823116381255131, 0.32652024259897156, 0.31535707113673184]
[G Epoch 20/50] [loss: -1.311283]
epoch  9 gen:  [0.33552631578947406, 0.30394736842105252, 0.27127192982456116, 0.34836590624227598, 0.32658502963793506, 0.31539329138023697]
[G Epoch 21/50] [loss: -1.063525]
epoch  9 gen:  [0.33552631578947406, 0.30350877192982445, 0.27127192982456116, 0.34836590624227598, 0.32629729920479394, 0.31536987570298802]
[G Epoch 22/50] [loss: -1.085476]
epoch  9 gen:  [0.33552631578947406, 0.30350877192982445, 0.27127192982456116, 0.3487457

[D Epoch 72/100] [loss: 0.674282]
[D Epoch 73/100] [loss: 0.674082]
[D Epoch 74/100] [loss: 0.673882]
[D Epoch 75/100] [loss: 0.673682]
[D Epoch 76/100] [loss: 0.673480]
[D Epoch 77/100] [loss: 0.673278]
[D Epoch 78/100] [loss: 0.673075]
[D Epoch 79/100] [loss: 0.672871]
[D Epoch 80/100] [loss: 0.672667]
[D Epoch 81/100] [loss: 0.672462]
[D Epoch 82/100] [loss: 0.672256]
[D Epoch 83/100] [loss: 0.672049]
[D Epoch 84/100] [loss: 0.671842]
[D Epoch 85/100] [loss: 0.671634]
[D Epoch 86/100] [loss: 0.671425]
[D Epoch 87/100] [loss: 0.671216]
[D Epoch 88/100] [loss: 0.671005]
[D Epoch 89/100] [loss: 0.670794]
[D Epoch 90/100] [loss: 0.670583]
[D Epoch 91/100] [loss: 0.670370]
[D Epoch 92/100] [loss: 0.670157]
[D Epoch 93/100] [loss: 0.669943]
[D Epoch 94/100] [loss: 0.669728]
[D Epoch 95/100] [loss: 0.669512]
[D Epoch 96/100] [loss: 0.669296]
[D Epoch 97/100] [loss: 0.669079]
[D Epoch 98/100] [loss: 0.668861]
[D Epoch 99/100] [loss: 0.668642]
[G Epoch 0/50] [loss: -1.203949]
epoch  10 gen: 

epoch  10 gen:  [0.33479532163742737, 0.30526315789473663, 0.27039473684210502, 0.34771660380594743, 0.32732679833948292, 0.31476848127879398]
[G Epoch 42/50] [loss: -1.049392]
epoch  10 gen:  [0.33479532163742737, 0.30526315789473663, 0.27039473684210502, 0.34798608866539693, 0.32758675199925524, 0.31497063209873399]
[G Epoch 43/50] [loss: -0.545344]
epoch  10 gen:  [0.33552631578947412, 0.30526315789473663, 0.27017543859649096, 0.34823116381255137, 0.32746333027187668, 0.31477606654360868]
[G Epoch 44/50] [loss: -1.205735]
epoch  10 gen:  [0.33479532163742737, 0.30526315789473663, 0.27017543859649096, 0.34758186137622266, 0.32726201130051946, 0.31460680705795352]
[G Epoch 45/50] [loss: -0.896280]
epoch  10 gen:  [0.33479532163742737, 0.30526315789473663, 0.27017543859649096, 0.34809642138282659, 0.32763389749051242, 0.31487852233597829]
[G Epoch 46/50] [loss: -2.071220]
epoch  10 gen:  [0.33479532163742737, 0.30526315789473663, 0.27017543859649096, 0.34809642138282659, 0.327601302595

[G Epoch 19/50] [loss: -1.581798]
epoch  11 gen:  [0.33479532163742731, 0.30482456140350866, 0.27061403508771903, 0.34720204379934355, 0.3266997766118101, 0.31475901151020058]
[G Epoch 20/50] [loss: -1.179013]
epoch  11 gen:  [0.33479532163742731, 0.30482456140350866, 0.27061403508771903, 0.34720204379934355, 0.3266997766118101, 0.3147446120904242]
[G Epoch 21/50] [loss: -0.977630]
epoch  11 gen:  [0.33479532163742731, 0.30482456140350866, 0.27039473684210497, 0.34758186137622266, 0.32697428086737834, 0.3147503016320094]
[G Epoch 22/50] [loss: -1.201892]
epoch  11 gen:  [0.33479532163742731, 0.30482456140350866, 0.27061403508771897, 0.34758186137622266, 0.32697428086737834, 0.31489357256997358]
[G Epoch 23/50] [loss: -1.178962]
epoch  11 gen:  [0.33479532163742731, 0.30482456140350866, 0.27061403508771903, 0.34720204379934355, 0.3266997766118101, 0.31475013984859923]
[G Epoch 24/50] [loss: -1.244049]
epoch  11 gen:  [0.33479532163742731, 0.30482456140350866, 0.27061403508771903, 0.3472

[D Epoch 81/100] [loss: 0.615695]
[D Epoch 82/100] [loss: 0.615334]
[D Epoch 83/100] [loss: 0.614973]
[D Epoch 84/100] [loss: 0.614612]
[D Epoch 85/100] [loss: 0.614250]
[D Epoch 86/100] [loss: 0.613887]
[D Epoch 87/100] [loss: 0.613524]
[D Epoch 88/100] [loss: 0.613160]
[D Epoch 89/100] [loss: 0.612795]
[D Epoch 90/100] [loss: 0.612430]
[D Epoch 91/100] [loss: 0.612064]
[D Epoch 92/100] [loss: 0.611698]
[D Epoch 93/100] [loss: 0.611331]
[D Epoch 94/100] [loss: 0.610963]
[D Epoch 95/100] [loss: 0.610595]
[D Epoch 96/100] [loss: 0.610226]
[D Epoch 97/100] [loss: 0.609856]
[D Epoch 98/100] [loss: 0.609486]
[D Epoch 99/100] [loss: 0.609115]
[G Epoch 0/50] [loss: -1.495908]
epoch  12 gen:  [0.33406432748538051, 0.30438596491228059, 0.27061403508771897, 0.34744711894649788, 0.32684430403749215, 0.31495146923909945]
[G Epoch 1/50] [loss: -1.059547]
epoch  12 gen:  [0.33552631578947412, 0.30438596491228059, 0.27061403508771897, 0.34861098138943047, 0.32701221279923681, 0.31508806584847132]
[G

[G Epoch 43/50] [loss: -1.697978]
epoch  12 gen:  [0.33625730994152087, 0.30482456140350866, 0.27127192982456111, 0.34861098138943042, 0.32701728749675829, 0.31519114576619905]
[G Epoch 44/50] [loss: -1.395586]
epoch  12 gen:  [0.33552631578947412, 0.30482456140350866, 0.27127192982456111, 0.34809642138282643, 0.32696572663536771, 0.31518190208205005]
[G Epoch 45/50] [loss: -1.169480]
epoch  12 gen:  [0.33552631578947412, 0.30482456140350866, 0.27127192982456111, 0.34809642138282643, 0.32696572663536771, 0.31516506571135172]
[G Epoch 46/50] [loss: -1.512090]
epoch  12 gen:  [0.33552631578947412, 0.30482456140350866, 0.27149122807017517, 0.34809642138282643, 0.32699832153082903, 0.31531277454989171]
[G Epoch 47/50] [loss: -1.539401]
epoch  12 gen:  [0.33552631578947412, 0.30482456140350866, 0.2710526315789471, 0.34809642138282643, 0.32699832153082903, 0.31503885478100324]
[G Epoch 48/50] [loss: -1.174887]
epoch  12 gen:  [0.33479532163742731, 0.30526315789473668, 0.27127192982456111, 0.

[G Epoch 21/50] [loss: -1.595291]
epoch  13 gen:  [0.33552631578947412, 0.30394736842105258, 0.2710526315789471, 0.34782693652337704, 0.32630414463132396, 0.31491919538966168]
[G Epoch 22/50] [loss: -1.147838]
epoch  13 gen:  [0.33552631578947412, 0.30394736842105258, 0.2710526315789471, 0.34809642138282643, 0.3264989085001736, 0.31503118439411559]
[G Epoch 23/50] [loss: -1.272978]
epoch  13 gen:  [0.33552631578947412, 0.30394736842105258, 0.2710526315789471, 0.34809642138282643, 0.32653150339563491, 0.31506928880904073]
[G Epoch 24/50] [loss: -1.090888]
epoch  13 gen:  [0.33552631578947412, 0.30394736842105258, 0.2710526315789471, 0.34796167895310171, 0.3264667163566714, 0.31502970334053904]
[G Epoch 25/50] [loss: -1.336622]
epoch  13 gen:  [0.33552631578947412, 0.30394736842105258, 0.2710526315789471, 0.34809642138282643, 0.32653150339563491, 0.31506096059343686]
[G Epoch 26/50] [loss: -1.206918]
epoch  13 gen:  [0.33552631578947412, 0.30438596491228059, 0.2710526315789471, 0.3480964

[D Epoch 91/100] [loss: 0.530860]
[D Epoch 92/100] [loss: 0.530437]
[D Epoch 93/100] [loss: 0.530015]
[D Epoch 94/100] [loss: 0.529592]
[D Epoch 95/100] [loss: 0.529170]
[D Epoch 96/100] [loss: 0.528747]
[D Epoch 97/100] [loss: 0.528325]
[D Epoch 98/100] [loss: 0.527903]
[D Epoch 99/100] [loss: 0.527481]
[G Epoch 0/50] [loss: -1.159644]
epoch  14 gen:  [0.33698830409356767, 0.30438596491228059, 0.27149122807017517, 0.34950535897291346, 0.32722945470258669, 0.31560173715114542]
[G Epoch 1/50] [loss: -1.060254]
epoch  14 gen:  [0.33625730994152092, 0.30438596491228059, 0.27149122807017517, 0.34885605653658486, 0.32708051190677129, 0.31551175073359256]
[G Epoch 2/50] [loss: -1.582360]
epoch  14 gen:  [0.33698830409356767, 0.30438596491228059, 0.27171052631578918, 0.34937061654318874, 0.32713207276816186, 0.31568332736846627]
[G Epoch 3/50] [loss: -1.447076]
epoch  14 gen:  [0.33698830409356767, 0.30394736842105252, 0.27127192982456111, 0.3497504341200679, 0.32711884659058899, 0.3155581273

[G Epoch 45/50] [loss: -1.187895]
epoch  14 gen:  [0.33698830409356773, 0.30614035087719282, 0.2710526315789471, 0.34975043412006795, 0.32859009365175579, 0.31553503937668648]
[G Epoch 46/50] [loss: -2.101188]
epoch  14 gen:  [0.33698830409356773, 0.30570175438596475, 0.27083333333333304, 0.34937061654318874, 0.32802785896304654, 0.31515589117267656]
[G Epoch 47/50] [loss: -0.752619]
epoch  14 gen:  [0.33698830409356773, 0.30570175438596475, 0.27083333333333304, 0.34937061654318874, 0.32799526406758522, 0.31510395795756635]
[G Epoch 48/50] [loss: -1.839454]
epoch  14 gen:  [0.33698830409356773, 0.30570175438596475, 0.2710526315789471, 0.34937061654318874, 0.32806045385850785, 0.31532647508898493]
[G Epoch 49/50] [loss: -1.450778]
epoch  14 gen:  [0.33698830409356773, 0.30526315789473668, 0.27083333333333304, 0.34937061654318874, 0.3278053183208281, 0.31517756629748911]


In [116]:
generator.user_embeddings.weight = torch.nn.Parameter(torch.tensor(param[0]))
generator.item_embeddings.weight = torch.nn.Parameter(torch.tensor(param[1]))

In [108]:
generator

Generator(
  (user_embeddings): Embedding(943, 5)
  (item_embeddings): Embedding(1683, 5)
  (item_bias): Embedding(1683, 1)
)