# ライブラリのインポート

In [22]:
import numpy as np
import utils as ut
import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision.models as models
import multiprocessing


# Discriminator

In [13]:
class Discriminator(nn.Module):
    
    def __init__(self, item_num, user_num, emb_dim,
                           lamda, param=None, initdelta=0.05):
        super(Discriminator, self).__init__()
        self.item_num = item_num
        self.user_num = user_num
        self.emb_dim = emb_dim
        self.lamda = lamda
        self.param = param
        self.initdelta = initdelta
        self.d_params = []
        
        if self.param is None:
            self.user_embeddings = nn.Embedding(
                self.user_num, self.emb_dim)            
            self.item_embeddings = nn.Embedding(
                self.item_num, self.emb_dim)            
            self.item_bias = nn.Embedding(
                self.item_num, 1)
            
            # initialize
            torch.nn.init.uniform_(
                self.user_embeddings.weight, a=-initdelta, b=initdelta)
            torch.nn.init.uniform_(
                self.item_embeddings.weight, a=-initdelta, b=initdelta)
#             torch.nn.init.uniform_(
#                 self.item_bias.weight, a=initdelta, b=-initdelta)
            
    def forward(self, user, item, label):
        pre_logits = (self.user_embeddings(user) \
                              * self.item_embeddings(item)).squeeze().sum(1).view(-1,1) \
                                + self.item_bias(item).view(-1, 1)
        pre_loss =  F.binary_cross_entropy_with_logits(pre_logits, label)
        return pre_loss

    def all_rating(self, user):
        rating = torch.mm(self.user_embeddings(user).view(-1, 5),
                                               self.item_embeddings.weight.t()) + self.item_bias
        return all_rating

    def all_logits(self, user):
        logits = (self.user_embeddings(user) \
                              * self.item_embeddings.weight).sum(1) + self.item_bias
        return all_logits

    def get_reward(self, user, item):
        reward_logits = (self.user_embeddings(user) * self.item_embeddings(item)).squeeze().sum(1).view(-1,1) + self.item_bias(item).view(-1, 1)
        reward = 2 * (torch.sigmoid(reward_logits) - 0.5)
        return reward

# Generator

In [14]:
class Generator(nn.Module):
    
    def __init__(self, item_num, user_num, emb_dim,
                           lamda, param=None, initdelta=0.05, lr=0.05):
        super(Generator, self).__init__()
        self.item_num = item_num
        self.user_num = user_num
        self.emb_dim = emb_dim
        self.lamda = lamda
        self.param = param
        self.initdelta = initdelta
        self.lr = lr
        self.g_params = []
        
        import pickle
        with open("ml-100k/model_dns_ori.pkl", "rb") as f:
            param = pickle.load(f, encoding='latin1')
    
        if self.param is None:
            self.user_embeddings = nn.Embedding(
                self.user_num, self.emb_dim)
            self.item_embeddings = nn.Embedding(
                self.item_num, self.emb_dim)
            self.item_bias = torch.zeros(self.item_num)
            
        self.user_embeddings.weight = torch.nn.Parameter(torch.tensor(param[0]))
        self.item_embeddings.weight = torch.nn.Parameter(torch.tensor(param[1]))
        
            # initialize
#             torch.nn.init.uniform_(
#                 self.user_embeddings.weight, a=-initdelta, b=initdelta)
#             torch.nn.init.uniform_(
#                 self.item_embeddings.weight, a=-initdelta, b=initdelta)
        
    def forward(self, user, item, reward):
        softmax_score = F.softmax(self.all_logits(user).view(1, -1), -1)
        i_prob = torch.gather(softmax_score.view(-1), 0, item).clamp(min=1e-8)
        loss = - torch.mean(torch.log(i_prob) * reward) # \
#                    + self.lamda * (F.normalize(self.user_embeddings(user), p=2, dim=1) \
#                                              + F.normalize(self.item_embeddings(item), p=2, dim=1) \
#                                              + F.normalize(self.item_bias(item), p=2, dim=1))
        return loss

    def all_rating(self, user):
        rating = torch.mm(self.user_embeddings(user).view(-1, 5),
                                               self.item_embeddings.weight.t()) + self.item_bias
        return rating
    
    def all_logits(self, user):
        logits = (self.user_embeddings(user) \
                              * self.item_embeddings.weight).sum(1) + self.item_bias
        return logits

# 評価関数

In [15]:
def dcg_at_k(r, k):
    r = np.asfarray(r)[:k]
    return np.sum(r / np.log2(np.arange(2, r.size + 2)))


def ndcg_at_k(r, k):
    dcg_max = dcg_at_k(sorted(r, reverse=True), k)
    if not dcg_max:
        return 0.
    return dcg_at_k(r, k) / dcg_max

def simple_test_one_user(x):
    
    # import pdb; pdb.set_trace()
    rating = x[0]
    u = x[1]

    test_items = list(all_items - set(user_pos_train[u]))
    item_score = []
    for i in test_items:
        item_score.append((i, rating[i]))

    item_score = sorted(item_score, key=lambda x: x[1])
    item_score.reverse()
    item_sort = [x[0] for x in item_score]

    r = []
    for i in item_sort:
        if i in user_pos_test[u]:
            r.append(1)
        else:
            r.append(0)

    p_3 = np.mean(r[:3])
    p_5 = np.mean(r[:5])
    p_10 = np.mean(r[:10])
    
    ndcg_3 = ndcg_at_k(r, 3)
    ndcg_5 = ndcg_at_k(r, 5)
    ndcg_10 = ndcg_at_k(r, 10)

    return np.array([p_3, p_5, p_10, ndcg_3, ndcg_5, ndcg_10])

def simple_test(model):
    result = np.array([0.] * 6)
    pool = multiprocessing.Pool(cores)
    batch_size = 128
    test_users = list(user_pos_test.keys())
    test_user_num = len(test_users)
    index = 0
    
    while True:
        if index >= test_user_num:
            break
        user_batch = test_users[index:index + batch_size]
        index += batch_size
    
#         print(user_batch)
        user_batch_rating = model.all_rating(torch.tensor(user_batch))

        user_batch_rating = user_batch_rating.detach_().cpu().numpy()

        user_batch_rating_uid = zip(user_batch_rating, user_batch)
        batch_result = pool.map(simple_test_one_user, user_batch_rating_uid)
        for re in batch_result:
            result += re

    pool.close()
    ret = result / test_user_num
    ret = list(ret)
    return ret

# ハイパーパラメーター

In [16]:
EMB_DIM = 5
USER_NUM = 943
ITEM_NUM = 1683
BATCH_SIZE = 16
INIT_DELTA = 0.05

all_items = set(range(ITEM_NUM))
workdir = "ml-100k/"
DIS_TRAIN_FILE = workdir + "dis-train.txt"
cores = multiprocessing.cpu_count()

# ロードデータ

In [17]:
# positiveな要素だけを引っ張ってくる
user_pos_train = {}
with open(workdir + 'movielens-100k-train.txt')as fin:
    for line in fin:
        line = line.split()
        uid = int(line[0])
        iid = int(line[1])
        r = float(line[2])
        if r > 3.99:
            if uid in user_pos_train:
                user_pos_train[uid].append(iid)
            else:
                user_pos_train[uid] = [iid]

# testでnegativeな要素だけを引っ張ってくる                
user_pos_test = {}
with open(workdir + 'movielens-100k-test.txt')as fin:
    for line in fin:
        line = line.split()
        uid = int(line[0])
        iid = int(line[1])
        r = float(line[2])
        if r > 3.99:
            if uid in user_pos_test:
                user_pos_test[uid].append(iid)
            else:
                user_pos_test[uid] = [iid]

all_users = user_pos_train.keys()

# Generatorによるデータセレクト

In [18]:
def generate_for_d(model, filename):
    data = []
    for user in user_pos_train:
        pos = user_pos_train[user]

        rating = model.all_rating(torch.tensor(user))
        rating = rating.detach_().cpu().numpy()
        rating = np.array(rating) / 0.2  # Temperature
        exp_rating = np.exp(rating)
        prob = exp_rating / np.sum(exp_rating)

        neg = np.random.choice(np.arange(ITEM_NUM), size=len(pos), p=prob.reshape(-1,))
        for i in range(len(pos)):
            data.append(str(user) + '\t' + str(pos[i]) + '\t' + str(neg[i]))

    with open(filename, 'w')as fout:
        fout.write('\n'.join(data))

In [19]:
generator = Generator(
    ITEM_NUM, USER_NUM,EMB_DIM, lamda=0.0 / BATCH_SIZE,
    param=None, initdelta=INIT_DELTA)

discriminator = Discriminator(
    ITEM_NUM, USER_NUM,EMB_DIM, lamda=0.0 / BATCH_SIZE,
    param=None, initdelta=INIT_DELTA)

g_optimizer = torch.optim.SGD(
    generator.parameters(), lr=0.001, momentum=0.9)

d_optimizer = torch.optim.SGD(
    discriminator.parameters(), lr=0.001, momentum=0.9)                    

In [20]:
from torch import nn
from torch import autograd

generator.train()
discriminator.train()
for epoch in range(15):
    if epoch >= 0:
        for d_epoch in range(100):
            if d_epoch % 5 == 0:
                generate_for_d(generator, DIS_TRAIN_FILE)
                train_size = ut.file_len(DIS_TRAIN_FILE)
            index = 1
            while True:
                if index > train_size:
                    break
                if index + BATCH_SIZE <= train_size + 1:
                    users, items, labels = ut.get_batch_data(
                        DIS_TRAIN_FILE, index, BATCH_SIZE)
                else:
                    users, items, labels = ut.get_batch_data(
                        DIS_TRAIN_FILE, index, train_size - index + 1)
                    
                index += BATCH_SIZE
                users = torch.tensor(users).view(-1, 1)
                items = torch.tensor(items).view(-1, 1)
                labels = torch.tensor(labels).view(-1, 1)
    
                d_loss = discriminator(users, items, labels)
                d_optimizer.zero_grad()
                d_loss.backward()
                d_optimizer.step()
            print("\r[D Epoch %d/%d] [loss: %f]" %(d_epoch, 100, d_loss.item()))

        for g_epoch in range(50):
            for user in user_pos_train:
                sample_lambda = 0.2
                pos = user_pos_train[user]

                rating = generator.all_logits(torch.tensor(user))
                rating = rating.detach_().cpu().numpy()
                exp_rating = np.exp(rating)
                prob = exp_rating / np.sum(exp_rating)
                
                pn = (1 - sample_lambda) * prob
                pn[pos] += sample_lambda * 1.0 / len(pos)
                
                sample = np.random.choice(range(ITEM_NUM), 2*len(pos), p=pn)
                reward = discriminator.get_reward(torch.tensor([user] * 2 * len(pos)), torch.tensor(sample))
                reward = reward.detach_().cpu().numpy()
                reward = reward * prob[sample] / pn[sample]

                g_loss = generator(torch.tensor(user), torch.tensor(sample), torch.tensor(reward))
                g_optimizer.zero_grad()
                g_loss.backward()
                g_optimizer.step()
                
            print("\r[G Epoch %d/%d] [loss: %f]" %(g_epoch, 50, g_loss.item()))
            result = simple_test(generator)
            print("epoch ", epoch, "gen: ", result)
#             buf = '\t'.join([str(x) for x in result])
#             gen_log.write(str(epoch) + '\t' + buf + '\n')
#             gen_log.flush()

[D Epoch 0/100] [loss: 0.705989]
[D Epoch 1/100] [loss: 0.705915]
[D Epoch 2/100] [loss: 0.705850]
[D Epoch 3/100] [loss: 0.705794]
[D Epoch 4/100] [loss: 0.705746]
[D Epoch 5/100] [loss: 0.705705]
[D Epoch 6/100] [loss: 0.705671]
[D Epoch 7/100] [loss: 0.705643]
[D Epoch 8/100] [loss: 0.705620]
[D Epoch 9/100] [loss: 0.705601]
[D Epoch 10/100] [loss: 0.705586]
[D Epoch 11/100] [loss: 0.705576]
[D Epoch 12/100] [loss: 0.705567]
[D Epoch 13/100] [loss: 0.705562]
[D Epoch 14/100] [loss: 0.705558]
[D Epoch 15/100] [loss: 0.705556]
[D Epoch 16/100] [loss: 0.705555]
[D Epoch 17/100] [loss: 0.705555]
[D Epoch 18/100] [loss: 0.705555]
[D Epoch 19/100] [loss: 0.705555]
[D Epoch 20/100] [loss: 0.705555]
[D Epoch 21/100] [loss: 0.705555]
[D Epoch 22/100] [loss: 0.705554]
[D Epoch 23/100] [loss: 0.705552]
[D Epoch 24/100] [loss: 0.705549]
[D Epoch 25/100] [loss: 0.705545]
[D Epoch 26/100] [loss: 0.705539]
[D Epoch 27/100] [loss: 0.705531]
[D Epoch 28/100] [loss: 0.705522]
[D Epoch 29/100] [loss: 

[G Epoch 28/50] [loss: 0.000518]
epoch  0 gen:  [0.33260233918128684, 0.3065789473684209, 0.27127192982456105, 0.34397383200691178, 0.32675163881732272, 0.31277947942931961]
[G Epoch 29/50] [loss: 0.023846]
epoch  0 gen:  [0.33260233918128684, 0.3065789473684209, 0.27149122807017512, 0.34397383200691178, 0.32678423371278403, 0.31294808022092346]
[G Epoch 30/50] [loss: 0.276036]
epoch  0 gen:  [0.33260233918128684, 0.3065789473684209, 0.27127192982456105, 0.34397383200691178, 0.32678423371278403, 0.31279670340891524]
[G Epoch 31/50] [loss: 0.251526]
epoch  0 gen:  [0.33260233918128684, 0.3065789473684209, 0.27105263157894705, 0.34410857443663651, 0.32688161564720886, 0.31275456607097391]
[G Epoch 32/50] [loss: 0.460024]
epoch  0 gen:  [0.33260233918128684, 0.3065789473684209, 0.27105263157894705, 0.34397383200691178, 0.32678423371278403, 0.31265718413654908]
[G Epoch 33/50] [loss: 0.599711]
epoch  0 gen:  [0.33260233918128684, 0.3065789473684209, 0.27105263157894705, 0.34397383200691178

[G Epoch 6/50] [loss: 0.053234]
epoch  1 gen:  [0.33260233918128684, 0.30614035087719282, 0.27149122807017512, 0.34424331686636128, 0.32672386204395393, 0.3131106435430232]
[G Epoch 7/50] [loss: 0.203388]
epoch  1 gen:  [0.33260233918128684, 0.30614035087719282, 0.27127192982456105, 0.34424331686636128, 0.32672386204395393, 0.31297809137259214]
[G Epoch 8/50] [loss: 0.462276]
epoch  1 gen:  [0.33260233918128684, 0.3065789473684209, 0.27127192982456105, 0.34424331686636128, 0.32701159247709505, 0.31296737964961963]
[G Epoch 9/50] [loss: -0.239117]
epoch  1 gen:  [0.33260233918128684, 0.3065789473684209, 0.27127192982456105, 0.34424331686636128, 0.32701159247709505, 0.31297894850350916]
[G Epoch 10/50] [loss: -0.109226]
epoch  1 gen:  [0.33260233918128684, 0.30614035087719282, 0.27149122807017512, 0.34424331686636128, 0.32672386204395393, 0.31310201920056141]
[G Epoch 11/50] [loss: 0.130236]
epoch  1 gen:  [0.33260233918128684, 0.30614035087719282, 0.27149122807017512, 0.3442433168663612

[D Epoch 15/100] [loss: 0.684988]
[D Epoch 16/100] [loss: 0.684898]
[D Epoch 17/100] [loss: 0.684809]
[D Epoch 18/100] [loss: 0.684721]
[D Epoch 19/100] [loss: 0.684634]
[D Epoch 20/100] [loss: 0.684547]
[D Epoch 21/100] [loss: 0.684461]
[D Epoch 22/100] [loss: 0.684376]
[D Epoch 23/100] [loss: 0.684292]
[D Epoch 24/100] [loss: 0.684208]
[D Epoch 25/100] [loss: 0.684125]
[D Epoch 26/100] [loss: 0.684043]
[D Epoch 27/100] [loss: 0.683961]
[D Epoch 28/100] [loss: 0.683881]
[D Epoch 29/100] [loss: 0.683800]
[D Epoch 30/100] [loss: 0.683721]
[D Epoch 31/100] [loss: 0.683642]
[D Epoch 32/100] [loss: 0.683564]
[D Epoch 33/100] [loss: 0.683487]
[D Epoch 34/100] [loss: 0.683410]
[D Epoch 35/100] [loss: 0.683335]
[D Epoch 36/100] [loss: 0.683259]
[D Epoch 37/100] [loss: 0.683185]
[D Epoch 38/100] [loss: 0.683111]
[D Epoch 39/100] [loss: 0.683038]
[D Epoch 40/100] [loss: 0.682966]
[D Epoch 41/100] [loss: 0.682894]
[D Epoch 42/100] [loss: 0.682823]
[D Epoch 43/100] [loss: 0.682753]
[D Epoch 44/10

[G Epoch 31/50] [loss: -0.236862]
epoch  2 gen:  [0.33406432748538045, 0.30614035087719282, 0.27039473684210497, 0.3450029520201196, 0.32682778927065331, 0.31242264620268889]
[G Epoch 32/50] [loss: -0.250584]
epoch  2 gen:  [0.33406432748538045, 0.30614035087719282, 0.2701754385964909, 0.3450029520201196, 0.32682778927065331, 0.31229175127278452]
[G Epoch 33/50] [loss: -0.187646]
epoch  2 gen:  [0.33406432748538045, 0.30614035087719282, 0.27039473684210497, 0.3450029520201196, 0.32682778927065331, 0.3124140218602271]
[G Epoch 34/50] [loss: -0.265352]
epoch  2 gen:  [0.33406432748538045, 0.30614035087719282, 0.2701754385964909, 0.34527243687956904, 0.32702255313950301, 0.31243507873440091]
[G Epoch 35/50] [loss: -0.542150]
epoch  2 gen:  [0.33406432748538045, 0.30614035087719282, 0.27061403508771897, 0.34527243687956904, 0.32702255313950301, 0.31268658049629955]
[G Epoch 36/50] [loss: 0.045144]
epoch  2 gen:  [0.33406432748538045, 0.30614035087719282, 0.2701754385964909, 0.3451376944498

[G Epoch 9/50] [loss: -0.147155]
epoch  3 gen:  [0.33479532163742726, 0.30614035087719282, 0.27017543859649096, 0.34565225445644826, 0.32697673206646877, 0.3124133382354678]
[G Epoch 10/50] [loss: -0.406682]
epoch  3 gen:  [0.33479532163742726, 0.30614035087719282, 0.27017543859649096, 0.34565225445644826, 0.32700932696193008, 0.31243946991818972]
[G Epoch 11/50] [loss: -0.762689]
epoch  3 gen:  [0.33479532163742726, 0.30614035087719282, 0.27017543859649096, 0.34565225445644826, 0.32697673206646877, 0.31240851499194733]
[G Epoch 12/50] [loss: -0.319773]
epoch  3 gen:  [0.33479532163742726, 0.30614035087719282, 0.27017543859649096, 0.34565225445644826, 0.32697673206646877, 0.31239790337341294]
[G Epoch 13/50] [loss: -0.516966]
epoch  3 gen:  [0.33479532163742726, 0.30614035087719282, 0.27017543859649096, 0.34551751202672354, 0.32691194502750526, 0.31238840682799396]
[G Epoch 14/50] [loss: -0.231970]
epoch  3 gen:  [0.33479532163742726, 0.30614035087719282, 0.27039473684210497, 0.3455175

[D Epoch 30/100] [loss: 0.679052]
[D Epoch 31/100] [loss: 0.679066]
[D Epoch 32/100] [loss: 0.679081]
[D Epoch 33/100] [loss: 0.679096]
[D Epoch 34/100] [loss: 0.679111]
[D Epoch 35/100] [loss: 0.679127]
[D Epoch 36/100] [loss: 0.679142]
[D Epoch 37/100] [loss: 0.679158]
[D Epoch 38/100] [loss: 0.679174]
[D Epoch 39/100] [loss: 0.679190]
[D Epoch 40/100] [loss: 0.679206]
[D Epoch 41/100] [loss: 0.679223]
[D Epoch 42/100] [loss: 0.679240]
[D Epoch 43/100] [loss: 0.679257]
[D Epoch 44/100] [loss: 0.679274]
[D Epoch 45/100] [loss: 0.679291]
[D Epoch 46/100] [loss: 0.679308]
[D Epoch 47/100] [loss: 0.679326]
[D Epoch 48/100] [loss: 0.679344]
[D Epoch 49/100] [loss: 0.679362]
[D Epoch 50/100] [loss: 0.679380]
[D Epoch 51/100] [loss: 0.679399]
[D Epoch 52/100] [loss: 0.679417]
[D Epoch 53/100] [loss: 0.679436]
[D Epoch 54/100] [loss: 0.679454]
[D Epoch 55/100] [loss: 0.679473]
[D Epoch 56/100] [loss: 0.679492]
[D Epoch 57/100] [loss: 0.679512]
[D Epoch 58/100] [loss: 0.679531]
[D Epoch 59/10

[G Epoch 34/50] [loss: -0.328392]
epoch  4 gen:  [0.33552631578947406, 0.30614035087719277, 0.27149122807017517, 0.34666307408044456, 0.32749195764979039, 0.31435859454836268]
[G Epoch 35/50] [loss: -1.172761]
epoch  4 gen:  [0.33552631578947406, 0.30614035087719277, 0.27149122807017517, 0.34679781651016928, 0.32758933958421521, 0.31444559223311114]
[G Epoch 36/50] [loss: -0.717107]
epoch  4 gen:  [0.33552631578947406, 0.30570175438596475, 0.27149122807017517, 0.34666307408044456, 0.32720422721664932, 0.31435486450695915]
[G Epoch 37/50] [loss: -0.298549]
epoch  4 gen:  [0.33552631578947406, 0.30570175438596475, 0.27149122807017517, 0.34679781651016928, 0.32733420404653546, 0.31442403144576231]
[G Epoch 38/50] [loss: -0.091307]
epoch  4 gen:  [0.33552631578947406, 0.30570175438596475, 0.27149122807017517, 0.34666307408044456, 0.32723682211211064, 0.31437601636381873]
[G Epoch 39/50] [loss: -1.031968]
epoch  4 gen:  [0.33552631578947406, 0.30570175438596475, 0.27171052631578918, 0.34679

[G Epoch 12/50] [loss: -0.602750]
epoch  5 gen:  [0.33625730994152087, 0.30526315789473668, 0.27127192982456111, 0.34796167895310182, 0.32764719897268446, 0.31478388464858986]
[G Epoch 13/50] [loss: -0.903650]
epoch  5 gen:  [0.33625730994152087, 0.30526315789473668, 0.27127192982456111, 0.3482067541002562, 0.32775913150290514, 0.31485185547976557]
[G Epoch 14/50] [loss: -0.484257]
epoch  5 gen:  [0.33625730994152087, 0.30482456140350866, 0.27149122807017517, 0.34782693652337709, 0.32719689681419589, 0.31479834311137517]
[G Epoch 15/50] [loss: -0.896703]
epoch  5 gen:  [0.33552631578947406, 0.30438596491228065, 0.27171052631578918, 0.34769219409365232, 0.32713210977523244, 0.31508732828541264]
[G Epoch 16/50] [loss: -0.840670]
epoch  5 gen:  [0.33552631578947406, 0.30482456140350866, 0.27127192982456111, 0.34769219409365232, 0.32738724531291225, 0.31477553152515114]
[G Epoch 17/50] [loss: -0.794626]
epoch  5 gen:  [0.33552631578947406, 0.30482456140350866, 0.27127192982456111, 0.347692

[D Epoch 45/100] [loss: 0.684293]
[D Epoch 46/100] [loss: 0.684318]
[D Epoch 47/100] [loss: 0.684342]
[D Epoch 48/100] [loss: 0.684366]
[D Epoch 49/100] [loss: 0.684390]
[D Epoch 50/100] [loss: 0.684414]
[D Epoch 51/100] [loss: 0.684439]
[D Epoch 52/100] [loss: 0.684463]
[D Epoch 53/100] [loss: 0.684487]
[D Epoch 54/100] [loss: 0.684510]
[D Epoch 55/100] [loss: 0.684534]
[D Epoch 56/100] [loss: 0.684558]
[D Epoch 57/100] [loss: 0.684582]
[D Epoch 58/100] [loss: 0.684605]
[D Epoch 59/100] [loss: 0.684629]
[D Epoch 60/100] [loss: 0.684652]
[D Epoch 61/100] [loss: 0.684675]
[D Epoch 62/100] [loss: 0.684699]
[D Epoch 63/100] [loss: 0.684722]
[D Epoch 64/100] [loss: 0.684745]
[D Epoch 65/100] [loss: 0.684768]
[D Epoch 66/100] [loss: 0.684790]
[D Epoch 67/100] [loss: 0.684813]
[D Epoch 68/100] [loss: 0.684836]
[D Epoch 69/100] [loss: 0.684858]
[D Epoch 70/100] [loss: 0.684881]
[D Epoch 71/100] [loss: 0.684903]
[D Epoch 72/100] [loss: 0.684925]
[D Epoch 73/100] [loss: 0.684947]
[D Epoch 74/10

epoch  6 gen:  [0.33625730994152087, 0.3065789473684209, 0.27192982456140319, 0.34912554139603447, 0.32921639940483477, 0.3156529980370662]
[G Epoch 37/50] [loss: -1.037281]
epoch  6 gen:  [0.33625730994152087, 0.3065789473684209, 0.2721491228070172, 0.34912554139603447, 0.32921639940483477, 0.3157808937240395]
[G Epoch 38/50] [loss: -0.692177]
epoch  6 gen:  [0.33625730994152087, 0.3065789473684209, 0.27192982456140319, 0.34912554139603441, 0.32921639940483477, 0.31565938178799302]
[G Epoch 39/50] [loss: -0.729020]
epoch  6 gen:  [0.33552631578947406, 0.30701754385964902, 0.2721491228070172, 0.34861098138943053, 0.32948516387204663, 0.31584888987943582]
[G Epoch 40/50] [loss: -0.896365]
epoch  6 gen:  [0.33552631578947406, 0.30701754385964902, 0.2721491228070172, 0.34809642138282665, 0.32914587257751499, 0.31561353391500746]
[G Epoch 41/50] [loss: -1.030525]
epoch  6 gen:  [0.33552631578947406, 0.30701754385964902, 0.2721491228070172, 0.34809642138282665, 0.32914587257751499, 0.315602

[G Epoch 14/50] [loss: -0.578146]
epoch  7 gen:  [0.33552631578947406, 0.3065789473684209, 0.2721491228070172, 0.34861098138943053, 0.32913224364798288, 0.31575146648833036]
[G Epoch 15/50] [loss: -1.397089]
epoch  7 gen:  [0.33552631578947406, 0.3065789473684209, 0.2721491228070172, 0.34847623895970575, 0.32903486171355806, 0.31568130517428017]
[G Epoch 16/50] [loss: -0.973129]
epoch  7 gen:  [0.33552631578947406, 0.3065789473684209, 0.27192982456140319, 0.34847623895970575, 0.32903486171355806, 0.31553316155945221]
[G Epoch 17/50] [loss: -1.229291]
epoch  7 gen:  [0.33552631578947406, 0.3065789473684209, 0.2721491228070172, 0.34847623895970575, 0.32906745660901937, 0.31568063456487372]
[G Epoch 18/50] [loss: -1.449700]
epoch  7 gen:  [0.33552631578947406, 0.3065789473684209, 0.27236842105263126, 0.34847623895970575, 0.32903486171355806, 0.31581391142221416]
[G Epoch 19/50] [loss: -0.998334]
epoch  7 gen:  [0.33552631578947406, 0.3065789473684209, 0.27236842105263126, 0.34847623895970

[D Epoch 56/100] [loss: 0.686328]
[D Epoch 57/100] [loss: 0.686315]
[D Epoch 58/100] [loss: 0.686302]
[D Epoch 59/100] [loss: 0.686289]
[D Epoch 60/100] [loss: 0.686275]
[D Epoch 61/100] [loss: 0.686261]
[D Epoch 62/100] [loss: 0.686247]
[D Epoch 63/100] [loss: 0.686232]
[D Epoch 64/100] [loss: 0.686217]
[D Epoch 65/100] [loss: 0.686202]
[D Epoch 66/100] [loss: 0.686186]
[D Epoch 67/100] [loss: 0.686170]
[D Epoch 68/100] [loss: 0.686154]
[D Epoch 69/100] [loss: 0.686138]
[D Epoch 70/100] [loss: 0.686121]
[D Epoch 71/100] [loss: 0.686104]
[D Epoch 72/100] [loss: 0.686086]
[D Epoch 73/100] [loss: 0.686069]
[D Epoch 74/100] [loss: 0.686050]
[D Epoch 75/100] [loss: 0.686032]
[D Epoch 76/100] [loss: 0.686013]
[D Epoch 77/100] [loss: 0.685994]
[D Epoch 78/100] [loss: 0.685975]
[D Epoch 79/100] [loss: 0.685955]
[D Epoch 80/100] [loss: 0.685935]
[D Epoch 81/100] [loss: 0.685915]
[D Epoch 82/100] [loss: 0.685894]
[D Epoch 83/100] [loss: 0.685873]
[D Epoch 84/100] [loss: 0.685851]
[D Epoch 85/10

[G Epoch 39/50] [loss: -1.191746]
epoch  8 gen:  [0.33552631578947401, 0.30482456140350861, 0.27171052631578912, 0.34847623895970564, 0.32742109368114458, 0.31573242567014786]
[G Epoch 40/50] [loss: -1.137964]
epoch  8 gen:  [0.33552631578947401, 0.30482456140350861, 0.27192982456140319, 0.34834149652998092, 0.32732371174671976, 0.31579931835097658]
[G Epoch 41/50] [loss: -0.812121]
epoch  8 gen:  [0.33552631578947401, 0.30482456140350861, 0.27171052631578912, 0.34861098138943042, 0.32745805142448536, 0.31578635048749365]
[G Epoch 42/50] [loss: -1.561078]
epoch  8 gen:  [0.33552631578947401, 0.30482456140350861, 0.27149122807017512, 0.34861098138943042, 0.32749064631994668, 0.31543902448593786]
[G Epoch 43/50] [loss: -0.947068]
epoch  8 gen:  [0.33552631578947401, 0.30482456140350861, 0.27171052631578912, 0.34874572381915514, 0.3275880282543715, 0.31588092427911196]
[G Epoch 44/50] [loss: -1.194107]
epoch  8 gen:  [0.33552631578947401, 0.30482456140350861, 0.27192982456140319, 0.348745

[G Epoch 17/50] [loss: -0.826480]
epoch  9 gen:  [0.33552631578947401, 0.30482456140350861, 0.27149122807017512, 0.34861098138943042, 0.32745328582464678, 0.31557817661356469]
[G Epoch 18/50] [loss: -0.765723]
epoch  9 gen:  [0.33552631578947401, 0.30482456140350861, 0.27192982456140319, 0.34874572381915514, 0.32755066775907155, 0.31591326687433585]
[G Epoch 19/50] [loss: -0.637187]
epoch  9 gen:  [0.33552631578947401, 0.30482456140350861, 0.27192982456140319, 0.34874572381915514, 0.32755066775907155, 0.31591349634597993]
[G Epoch 20/50] [loss: -1.176858]
epoch  9 gen:  [0.33552631578947401, 0.30438596491228059, 0.27192982456140319, 0.34861098138943042, 0.3271655553915056, 0.31585635802737988]
[G Epoch 21/50] [loss: -1.607619]
epoch  9 gen:  [0.33479532163742715, 0.30482456140350861, 0.27192982456140319, 0.34809642138282643, 0.3274451779034217, 0.3158507914081205]
[G Epoch 22/50] [loss: -0.790557]
epoch  9 gen:  [0.33479532163742715, 0.30438596491228059, 0.27192982456140319, 0.34809642

[D Epoch 71/100] [loss: 0.673754]
[D Epoch 72/100] [loss: 0.673633]
[D Epoch 73/100] [loss: 0.673512]
[D Epoch 74/100] [loss: 0.673389]
[D Epoch 75/100] [loss: 0.673267]
[D Epoch 76/100] [loss: 0.673143]
[D Epoch 77/100] [loss: 0.673019]
[D Epoch 78/100] [loss: 0.672893]
[D Epoch 79/100] [loss: 0.672767]
[D Epoch 80/100] [loss: 0.672641]
[D Epoch 81/100] [loss: 0.672513]
[D Epoch 82/100] [loss: 0.672385]
[D Epoch 83/100] [loss: 0.672256]
[D Epoch 84/100] [loss: 0.672126]
[D Epoch 85/100] [loss: 0.671996]
[D Epoch 86/100] [loss: 0.671865]
[D Epoch 87/100] [loss: 0.671732]
[D Epoch 88/100] [loss: 0.671599]
[D Epoch 89/100] [loss: 0.671466]
[D Epoch 90/100] [loss: 0.671331]
[D Epoch 91/100] [loss: 0.671196]
[D Epoch 92/100] [loss: 0.671060]
[D Epoch 93/100] [loss: 0.670923]
[D Epoch 94/100] [loss: 0.670785]
[D Epoch 95/100] [loss: 0.670647]
[D Epoch 96/100] [loss: 0.670507]
[D Epoch 97/100] [loss: 0.670367]
[D Epoch 98/100] [loss: 0.670227]
[D Epoch 99/100] [loss: 0.670085]
[G Epoch 0/50]

[G Epoch 41/50] [loss: -0.814039]
epoch  10 gen:  [0.33552631578947406, 0.30614035087719282, 0.27061403508771897, 0.34858657167713525, 0.32839246262720989, 0.31498158841906193]
[G Epoch 42/50] [loss: -0.726235]
epoch  10 gen:  [0.33552631578947406, 0.30614035087719282, 0.27061403508771897, 0.34858657167713525, 0.32839246262720989, 0.31497054818467751]
[G Epoch 43/50] [loss: -0.844184]
epoch  10 gen:  [0.33552631578947406, 0.30614035087719282, 0.27083333333333304, 0.34858657167713525, 0.32839246262720989, 0.31512651397283686]
[G Epoch 44/50] [loss: -0.964462]
epoch  10 gen:  [0.33552631578947406, 0.30614035087719282, 0.27127192982456111, 0.34858657167713525, 0.3284250575226712, 0.31541437450940429]
[G Epoch 45/50] [loss: -1.078469]
epoch  10 gen:  [0.33552631578947406, 0.30614035087719282, 0.27105263157894705, 0.34858657167713525, 0.32839246262720989, 0.3152676695285434]
[G Epoch 46/50] [loss: -1.029049]
epoch  10 gen:  [0.33552631578947406, 0.30614035087719282, 0.27105263157894705, 0.3

[G Epoch 19/50] [loss: -1.482238]
epoch  11 gen:  [0.33552631578947395, 0.30657894736842084, 0.27039473684210491, 0.34872131410686008, 0.32887535968115988, 0.31471475698942042]
[G Epoch 20/50] [loss: -1.512121]
epoch  11 gen:  [0.3347953216374272, 0.30614035087719277, 0.27039473684210491, 0.34820675410025614, 0.32853606838662819, 0.31464414612781622]
[G Epoch 21/50] [loss: -1.322604]
epoch  11 gen:  [0.3347953216374272, 0.30614035087719277, 0.27039473684210491, 0.34820675410025614, 0.32853606838662819, 0.31463140394856653]
[G Epoch 22/50] [loss: -0.922538]
epoch  11 gen:  [0.3347953216374272, 0.30614035087719277, 0.27105263157894705, 0.34820675410025614, 0.32853606838662819, 0.31506446141676187]
[G Epoch 23/50] [loss: -1.327455]
epoch  11 gen:  [0.3347953216374272, 0.3057017543859647, 0.27105263157894705, 0.34820675410025614, 0.32828093284894838, 0.31510234987696673]
[G Epoch 24/50] [loss: -0.839654]
epoch  11 gen:  [0.3340643274853804, 0.3057017543859647, 0.27127192982456111, 0.347692

[D Epoch 81/100] [loss: 0.629684]
[D Epoch 82/100] [loss: 0.629380]
[D Epoch 83/100] [loss: 0.629075]
[D Epoch 84/100] [loss: 0.628769]
[D Epoch 85/100] [loss: 0.628462]
[D Epoch 86/100] [loss: 0.628155]
[D Epoch 87/100] [loss: 0.627847]
[D Epoch 88/100] [loss: 0.627537]
[D Epoch 89/100] [loss: 0.627227]
[D Epoch 90/100] [loss: 0.626917]
[D Epoch 91/100] [loss: 0.626605]
[D Epoch 92/100] [loss: 0.626293]
[D Epoch 93/100] [loss: 0.625979]
[D Epoch 94/100] [loss: 0.625665]
[D Epoch 95/100] [loss: 0.625350]
[D Epoch 96/100] [loss: 0.625034]
[D Epoch 97/100] [loss: 0.624718]
[D Epoch 98/100] [loss: 0.624400]
[D Epoch 99/100] [loss: 0.624082]
[G Epoch 0/50] [loss: -1.370675]
epoch  12 gen:  [0.33552631578947406, 0.30526315789473663, 0.27127192982456111, 0.34899079896630952, 0.32814174245966354, 0.31565559716967667]
[G Epoch 1/50] [loss: -1.130915]
epoch  12 gen:  [0.33552631578947406, 0.30482456140350861, 0.27192982456140319, 0.34899079896630952, 0.32788660692198379, 0.31608913985130921]
[G

[G Epoch 43/50] [loss: -0.793249]
epoch  12 gen:  [0.3369883040935675, 0.30526315789473663, 0.27149122807017501, 0.34975043412006795, 0.32794755002737264, 0.31522401510638676]
[G Epoch 44/50] [loss: -1.124830]
epoch  12 gen:  [0.33625730994152075, 0.30482456140350861, 0.27149122807017501, 0.34923587411346402, 0.32757566383737963, 0.31513261186405633]
[G Epoch 45/50] [loss: -1.349288]
epoch  12 gen:  [0.33625730994152075, 0.30394736842105252, 0.271271929824561, 0.34923587411346402, 0.32706539276202001, 0.31499683658672073]
[G Epoch 46/50] [loss: -1.078575]
epoch  12 gen:  [0.33552631578947395, 0.30438596491228054, 0.27149122807017501, 0.34872131410686014, 0.32723637254284788, 0.31507165381630542]
[G Epoch 47/50] [loss: -1.218204]
epoch  12 gen:  [0.33552631578947395, 0.30438596491228054, 0.27149122807017501, 0.34885605653658486, 0.32740370986803385, 0.31526235315484297]
[G Epoch 48/50] [loss: -0.942908]
epoch  12 gen:  [0.33552631578947395, 0.30438596491228054, 0.27149122807017506, 0.34

[G Epoch 21/50] [loss: -1.222988]
epoch  13 gen:  [0.33771929824561436, 0.30350877192982445, 0.27192982456140319, 0.35039973655639656, 0.32701076088154624, 0.31565303020249136]
[G Epoch 22/50] [loss: -0.925465]
epoch  13 gen:  [0.33771929824561436, 0.30394736842105252, 0.27171052631578912, 0.35039973655639656, 0.3272984913146873, 0.31551035473066114]
[G Epoch 23/50] [loss: -0.875251]
epoch  13 gen:  [0.33771929824561436, 0.30394736842105252, 0.27149122807017512, 0.35039973655639656, 0.3272984913146873, 0.31537173775693295]
[G Epoch 24/50] [loss: -0.888614]
epoch  13 gen:  [0.33845029239766111, 0.30394736842105252, 0.27171052631578912, 0.35091429656300049, 0.32738264707153925, 0.315579757754885]
[G Epoch 25/50] [loss: -1.601693]
epoch  13 gen:  [0.33845029239766111, 0.30394736842105252, 0.27149122807017512, 0.35091429656300049, 0.32738264707153925, 0.31544023848251884]
[G Epoch 26/50] [loss: -1.118327]
epoch  13 gen:  [0.33845029239766111, 0.30350877192982445, 0.27171052631578912, 0.350

[D Epoch 91/100] [loss: 0.553042]
[D Epoch 92/100] [loss: 0.552650]
[D Epoch 93/100] [loss: 0.552257]
[D Epoch 94/100] [loss: 0.551865]
[D Epoch 95/100] [loss: 0.551473]
[D Epoch 96/100] [loss: 0.551080]
[D Epoch 97/100] [loss: 0.550688]
[D Epoch 98/100] [loss: 0.550296]
[D Epoch 99/100] [loss: 0.549904]
[G Epoch 0/50] [loss: -1.818725]
epoch  14 gen:  [0.33625730994152081, 0.30350877192982445, 0.2723684210526312, 0.34937061654318874, 0.3268424493678424, 0.31612071050909968]
[G Epoch 1/50] [loss: -1.414334]
epoch  14 gen:  [0.33698830409356756, 0.30350877192982445, 0.27258771929824527, 0.34988517654979268, 0.32692660512469435, 0.31635737702961364]
[G Epoch 2/50] [loss: -0.606696]
epoch  14 gen:  [0.33625730994152081, 0.30394736842105252, 0.27258771929824527, 0.34937061654318874, 0.32709758490552227, 0.31627404512386476]
[G Epoch 3/50] [loss: -1.037645]
epoch  14 gen:  [0.33625730994152081, 0.30438596491228054, 0.2723684210526312, 0.34937061654318874, 0.32732012554774065, 0.315966299133

[G Epoch 45/50] [loss: -0.706610]
epoch  14 gen:  [0.33625730994152087, 0.30219298245614024, 0.27171052631578907, 0.34896638925401446, 0.32571970716060594, 0.31530269145437273]
[G Epoch 46/50] [loss: -0.977212]
epoch  14 gen:  [0.33625730994152087, 0.30219298245614024, 0.27214912280701714, 0.34896638925401446, 0.32571970716060594, 0.31561271817301706]
[G Epoch 47/50] [loss: -1.039399]
epoch  14 gen:  [0.33698830409356761, 0.30219298245614024, 0.2723684210526312, 0.34948094926061846, 0.32577126802199646, 0.31576840685370394]
[G Epoch 48/50] [loss: -1.536139]
epoch  14 gen:  [0.33698830409356761, 0.30219298245614024, 0.2723684210526312, 0.34948094926061846, 0.32577126802199646, 0.315768406853704]
[G Epoch 49/50] [loss: -1.620250]
epoch  14 gen:  [0.33625730994152087, 0.30219298245614024, 0.27258771929824521, 0.34896638925401446, 0.32571970716060594, 0.31586573518093919]


In [21]:
generator.user_embeddings.weight = torch.nn.Parameter(torch.tensor(param[0]))
generator.item_embeddings.weight = torch.nn.Parameter(torch.tensor(param[1]))

NameError: name 'param' is not defined