# ライブラリのインポート

In [68]:
import numpy as np
import utils as ut
import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision.models as models
import multiprocessing

# Discriminator

In [89]:
class Discriminator(nn.Module):
    
    def __init__(self, item_num, user_num, emb_dim,
                           lamda, param=None, initdelta=0.05):
        super(Discriminator, self).__init__()
        self.item_num = item_num
        self.user_num = user_num
        self.emb_dim = emb_dim
        self.lamda = lamda
        self.param = param
        self.initdelta = initdelta
        self.d_params = []
        
        if self.param is None:
            self.user_embeddings = nn.Embedding(
                self.user_num, self.emb_dim)            
            self.item_embeddings = nn.Embedding(
                self.item_num, self.emb_dim)            
            self.item_bias = nn.Embedding(
                self.item_num, 1)
            
            # initialize
            torch.nn.init.uniform_(
                self.user_embeddings.weight, a=-initdelta, b=initdelta)
            torch.nn.init.uniform_(
                self.item_embeddings.weight, a=-initdelta, b=initdelta)
#             torch.nn.init.uniform_(
#                 self.item_bias.weight, a=initdelta, b=-initdelta)
            
    def forward(self, user, item, label):
        pre_logits = (self.user_embeddings(user) \
                              * self.item_embeddings(item)).squeeze().sum(1).view(-1,1) \
                                + self.item_bias(item).view(-1, 1)
        pre_loss =  F.binary_cross_entropy_with_logits(pre_logits, label)
        return pre_loss

    def all_rating(self, user):
        rating = torch.mm(self.user_embeddings(user).view(-1, 5),
                                               self.item_embeddings.weight.t()) + self.item_bias
        return all_rating

    def all_logits(self, user):
        logits = (self.user_embeddings(user) \
                              * self.item_embeddings.weight).sum(1) + self.item_bias
        return all_logits

    def get_reward(self, user, item):
        reward_logits = (self.user_embeddings(user) * self.item_embeddings(item)).squeeze().sum(1).view(-1,1) + self.item_bias(item).view(-1, 1)
        reward = 2 * (torch.sigmoid(reward_logits) - 0.5)
        return reward

# Generator

In [117]:
class Generator(nn.Module):
    
    def __init__(self, item_num, user_num, emb_dim,
                           lamda, param=None, initdelta=0.05, lr=0.05):
        super(Generator, self).__init__()
        self.item_num = item_num
        self.user_num = user_num
        self.emb_dim = emb_dim
        self.lamda = lamda
        self.param = param
        self.initdelta = initdelta
        self.lr = lr
        self.g_params = []
        
        import pickle
        with open("ml-100k/model_dns_ori.pkl", "rb") as f:
            param = pickle.load(f, encoding='latin1')
    
        if self.param is None:
            self.user_embeddings = nn.Embedding(
                self.user_num, self.emb_dim)
            self.item_embeddings = nn.Embedding(
                self.item_num, self.emb_dim)
            self.item_bias = torch.zeros(self.item_num)
            
        self.user_embeddings.weight = torch.nn.Parameter(torch.tensor(param[0]))
        self.item_embeddings.weight = torch.nn.Parameter(torch.tensor(param[1]))
        
            # initialize
#             torch.nn.init.uniform_(
#                 self.user_embeddings.weight, a=-initdelta, b=initdelta)
#             torch.nn.init.uniform_(
#                 self.item_embeddings.weight, a=-initdelta, b=initdelta)
        
    def forward(self, user, item, reward):
        softmax_score = F.softmax(self.all_logits(user).view(1, -1), -1)
        i_prob = torch.gather(softmax_score.view(-1), 0, item).clamp(min=1e-8)
        loss = - torch.mean(torch.log(i_prob) * reward) # \
#                    + self.lamda * (F.normalize(self.user_embeddings(user), p=2, dim=1) \
#                                              + F.normalize(self.item_embeddings(item), p=2, dim=1) \
#                                              + F.normalize(self.item_bias(item), p=2, dim=1))
        return loss

    def all_rating(self, user):
        rating = torch.mm(self.user_embeddings(user).view(-1, 5),
                                               self.item_embeddings.weight.t()) + self.item_bias
        return rating
    
    def all_logits(self, user):
        logits = (self.user_embeddings(user) \
                              * self.item_embeddings.weight).sum(1) + self.item_bias
        return logits

# 評価関数

In [118]:
def dcg_at_k(r, k):
    r = np.asfarray(r)[:k]
    return np.sum(r / np.log2(np.arange(2, r.size + 2)))


def ndcg_at_k(r, k):
    dcg_max = dcg_at_k(sorted(r, reverse=True), k)
    if not dcg_max:
        return 0.
    return dcg_at_k(r, k) / dcg_max

def simple_test_one_user(x):
    # import pdb; pdb.set_trace()
    rating = x[0]
    u = x[1]

    test_items = list(all_items - set(user_pos_train[u]))
    item_score = []
    for i in test_items:
        item_score.append((i, rating[i]))

    item_score = sorted(item_score, key=lambda x: x[1])
    item_score.reverse()
    item_sort = [x[0] for x in item_score]

    r = []
    for i in item_sort:
        if i in user_pos_test[u]:
            r.append(1)
        else:
            r.append(0)

    p_3 = np.mean(r[:3])
    p_5 = np.mean(r[:5])
    p_10 = np.mean(r[:10])
    ndcg_3 = ndcg_at_k(r, 3)
    ndcg_5 = ndcg_at_k(r, 5)
    ndcg_10 = ndcg_at_k(r, 10)

    return np.array([p_3, p_5, p_10, ndcg_3, ndcg_5, ndcg_10])

def simple_test(model):
    result = np.array([0.] * 6)
    pool = multiprocessing.Pool(cores)
    batch_size = 128
    test_users = list(user_pos_test.keys())
    test_user_num = len(test_users)
    index = 0
    while True:
        if index >= test_user_num:
            break
        user_batch = test_users[index:index + batch_size]
        index += batch_size
    
#         print(user_batch)
        user_batch_rating = model.all_rating(torch.tensor(user_batch))

        user_batch_rating = user_batch_rating.detach_().cpu().numpy()

        user_batch_rating_uid = zip(user_batch_rating, user_batch)
        batch_result = pool.map(simple_test_one_user, user_batch_rating_uid)
        for re in batch_result:
            result += re

    pool.close()
    ret = result / test_user_num
    ret = list(ret)
    return ret

# ハイパーパラメーター

In [119]:
EMB_DIM = 5
USER_NUM = 943
ITEM_NUM = 1683
BATCH_SIZE = 16
INIT_DELTA = 0.05

all_items = set(range(ITEM_NUM))
workdir = "ml-100k/"
DIS_TRAIN_FILE = workdir + "dis-train.txt"
cores = multiprocessing.cpu_count()

# ロードデータ

In [120]:
# positiveな要素だけを引っ張ってくる
user_pos_train = {}
with open(workdir + 'movielens-100k-train.txt')as fin:
    for line in fin:
        line = line.split()
        uid = int(line[0])
        iid = int(line[1])
        r = float(line[2])
        if r > 3.99:
            if uid in user_pos_train:
                user_pos_train[uid].append(iid)
            else:
                user_pos_train[uid] = [iid]

# testでnegativeな要素だけを引っ張ってくる                
user_pos_test = {}
with open(workdir + 'movielens-100k-test.txt')as fin:
    for line in fin:
        line = line.split()
        uid = int(line[0])
        iid = int(line[1])
        r = float(line[2])
        if r > 3.99:
            if uid in user_pos_test:
                user_pos_test[uid].append(iid)
            else:
                user_pos_test[uid] = [iid]

all_users = user_pos_train.keys()

# Generatorによるデータセレクト

In [121]:
def generate_for_d(model, filename):
    data = []
    for user in user_pos_train:
        pos = user_pos_train[user]

        rating = model.all_rating(torch.tensor(user))
        rating = rating.detach_().cpu().numpy()
        rating = np.array(rating) / 0.2  # Temperature
        exp_rating = np.exp(rating)
        prob = exp_rating / np.sum(exp_rating)

        neg = np.random.choice(np.arange(ITEM_NUM), size=len(pos), p=prob.reshape(-1,))
        for i in range(len(pos)):
            data.append(str(user) + '\t' + str(pos[i]) + '\t' + str(neg[i]))

    with open(filename, 'w')as fout:
        fout.write('\n'.join(data))

In [122]:
generator = Generator(
    ITEM_NUM, USER_NUM,EMB_DIM, lamda=0.0 / BATCH_SIZE,
    param=None, initdelta=INIT_DELTA)

discriminator = Discriminator(
    ITEM_NUM, USER_NUM,EMB_DIM, lamda=0.0 / BATCH_SIZE,
    param=None, initdelta=INIT_DELTA)

g_optimizer = torch.optim.SGD(
    generator.parameters(), lr=0.001, momentum=0.9)

d_optimizer = torch.optim.SGD(
    discriminator.parameters(), lr=0.001, momentum=0.9)                    

In [None]:
from torch import nn
from torch import autograd

generator.train()
discriminator.train()
for epoch in range(15):
    if epoch >= 0:
        for d_epoch in range(100):
            if d_epoch % 5 == 0:
                generate_for_d(generator, DIS_TRAIN_FILE)
                train_size = ut.file_len(DIS_TRAIN_FILE)
            index = 1
            while True:
                if index > train_size:
                    break
                if index + BATCH_SIZE <= train_size + 1:
                    users, items, labels = ut.get_batch_data(
                        DIS_TRAIN_FILE, index, BATCH_SIZE)
                else:
                    users, items, labels = ut.get_batch_data(
                        DIS_TRAIN_FILE, index, train_size - index + 1)
                    
                index += BATCH_SIZE
                users = torch.tensor(users).view(-1, 1)
                items = torch.tensor(items).view(-1, 1)
                labels = torch.tensor(labels).view(-1, 1)
    
                loss_d = discriminator(users, items, labels)
                d_optimizer.zero_grad()
                loss_d.backward()
                d_optimizer.step()
            print("\r[D Epoch %d/%d] [loss: %f]" %(d_epoch, 100, loss_d.item()))

        for g_epoch in range(50):
            for user in user_pos_train:
                sample_lambda = 0.2
                pos = user_pos_train[user]

                rating = generator.all_logits(torch.tensor(user))
                rating = rating.detach_().cpu().numpy()
                exp_rating = np.exp(rating)
                prob = exp_rating / np.sum(exp_rating)
                
                pn = (1 - sample_lambda) * prob
                pn[pos] += sample_lambda * 1.0 / len(pos)
                
                sample = np.random.choice(range(ITEM_NUM), 2*len(pos), p=pn)
                reward = discriminator.get_reward(torch.tensor([user] * 2 * len(pos)), torch.tensor(sample))
                reward = reward.detach_().cpu().numpy()
                reward = reward * prob[sample] / pn[sample]

                loss_g = generator(torch.tensor(user), torch.tensor(sample), torch.tensor(reward))
                g_optimizer.zero_grad()
                loss_g.backward()
                g_optimizer.step()
                
            print("\r[G Epoch %d/%d] [loss: %f]" %(g_epoch, 50, loss_g.item()))
            result = simple_test(generator)
            print("epoch ", epoch, "gen: ", result)
#             buf = '\t'.join([str(x) for x in result])
#             gen_log.write(str(epoch) + '\t' + buf + '\n')
#             gen_log.flush()

[D Epoch 0/100] [loss: 0.646515]
[D Epoch 1/100] [loss: 0.646522]
[D Epoch 2/100] [loss: 0.646529]
[D Epoch 3/100] [loss: 0.646536]
[D Epoch 4/100] [loss: 0.646543]
[D Epoch 5/100] [loss: 0.646551]
[D Epoch 6/100] [loss: 0.646559]
[D Epoch 7/100] [loss: 0.646566]
[D Epoch 8/100] [loss: 0.646574]
[D Epoch 9/100] [loss: 0.646582]
[D Epoch 10/100] [loss: 0.646591]
[D Epoch 11/100] [loss: 0.646599]
[D Epoch 12/100] [loss: 0.646607]
[D Epoch 13/100] [loss: 0.646616]
[D Epoch 14/100] [loss: 0.646625]
[D Epoch 15/100] [loss: 0.646633]
[D Epoch 16/100] [loss: 0.646642]
[D Epoch 17/100] [loss: 0.646651]
[D Epoch 18/100] [loss: 0.646660]
[D Epoch 19/100] [loss: 0.646669]
[D Epoch 20/100] [loss: 0.646679]
[D Epoch 21/100] [loss: 0.646688]
[D Epoch 22/100] [loss: 0.646697]
[D Epoch 23/100] [loss: 0.646706]
[D Epoch 24/100] [loss: 0.646715]
[D Epoch 25/100] [loss: 0.646725]
[D Epoch 26/100] [loss: 0.646734]
[D Epoch 27/100] [loss: 0.646744]
[D Epoch 28/100] [loss: 0.646754]
[D Epoch 29/100] [loss: 

[G Epoch 28/50] [loss: 0.235173]
epoch  0 gen:  [0.33187134502924015, 0.3065789473684209, 0.27105263157894705, 0.34321419685315352, 0.32657025711280918, 0.31260383548307924]
[G Epoch 29/50] [loss: 0.225934]
epoch  0 gen:  [0.33187134502924015, 0.30657894736842095, 0.27083333333333304, 0.34321419685315352, 0.32653766221734787, 0.31246287363503827]
[G Epoch 30/50] [loss: 0.216859]
epoch  0 gen:  [0.33187134502924015, 0.30657894736842095, 0.27083333333333304, 0.34321419685315352, 0.32653766221734787, 0.31246287363503827]
[G Epoch 31/50] [loss: 0.089096]
epoch  0 gen:  [0.33187134502924015, 0.30657894736842095, 0.27061403508771897, 0.34359401443003268, 0.32681216647291611, 0.31250148884167334]
[G Epoch 32/50] [loss: 0.397192]
epoch  0 gen:  [0.33187134502924015, 0.30657894736842095, 0.27039473684210497, 0.34359401443003268, 0.32681216647291611, 0.31236196956930723]
[G Epoch 33/50] [loss: -0.160964]
epoch  0 gen:  [0.33187134502924015, 0.30657894736842095, 0.27061403508771897, 0.34359401443

[G Epoch 6/50] [loss: -0.234353]
epoch  1 gen:  [0.33187134502924009, 0.30614035087719282, 0.27061403508771897, 0.34359401443003268, 0.32658962583069762, 0.31254123691183583]
[G Epoch 7/50] [loss: 0.127489]
epoch  1 gen:  [0.33187134502924009, 0.30614035087719282, 0.27061403508771897, 0.34359401443003268, 0.32658962583069762, 0.31254260685826479]
[G Epoch 8/50] [loss: 0.103808]
epoch  1 gen:  [0.33187134502924009, 0.30614035087719282, 0.27061403508771897, 0.34359401443003268, 0.32665481562162024, 0.31259822042082414]
[G Epoch 9/50] [loss: 0.210702]
epoch  1 gen:  [0.33187134502924009, 0.30614035087719282, 0.27061403508771897, 0.34359401443003268, 0.32665481562162024, 0.31259537115567687]
[G Epoch 10/50] [loss: 0.289559]
epoch  1 gen:  [0.33187134502924009, 0.30614035087719282, 0.27061403508771897, 0.34359401443003268, 0.32662222072615893, 0.3125742192988174]
[G Epoch 11/50] [loss: 0.071249]
epoch  1 gen:  [0.33114035087719335, 0.30614035087719282, 0.27061403508771897, 0.343079454423428

[D Epoch 15/100] [loss: 0.651705]
[D Epoch 16/100] [loss: 0.651749]
[D Epoch 17/100] [loss: 0.651793]
[D Epoch 18/100] [loss: 0.651836]
[D Epoch 19/100] [loss: 0.651880]
[D Epoch 20/100] [loss: 0.651925]
[D Epoch 21/100] [loss: 0.651969]
[D Epoch 22/100] [loss: 0.652013]
[D Epoch 23/100] [loss: 0.652058]
[D Epoch 24/100] [loss: 0.652102]
[D Epoch 25/100] [loss: 0.652147]
[D Epoch 26/100] [loss: 0.652192]
[D Epoch 27/100] [loss: 0.652236]
[D Epoch 28/100] [loss: 0.652281]
[D Epoch 29/100] [loss: 0.652326]
[D Epoch 30/100] [loss: 0.652371]
[D Epoch 31/100] [loss: 0.652416]
[D Epoch 32/100] [loss: 0.652461]
[D Epoch 33/100] [loss: 0.652507]
[D Epoch 34/100] [loss: 0.652552]
[D Epoch 35/100] [loss: 0.652597]
[D Epoch 36/100] [loss: 0.652643]
[D Epoch 37/100] [loss: 0.652688]
[D Epoch 38/100] [loss: 0.652734]
[D Epoch 39/100] [loss: 0.652780]
[D Epoch 40/100] [loss: 0.652826]
[D Epoch 41/100] [loss: 0.652872]
[D Epoch 42/100] [loss: 0.652918]
[D Epoch 43/100] [loss: 0.652964]
[D Epoch 44/10

[G Epoch 31/50] [loss: -0.096099]
epoch  2 gen:  [0.33260233918128684, 0.30570175438596481, 0.27105263157894699, 0.34437805929608606, 0.32653351354523757, 0.31344220916070792]
[G Epoch 32/50] [loss: -0.128449]
epoch  2 gen:  [0.33187134502924009, 0.30570175438596481, 0.27083333333333298, 0.34348368171260302, 0.32614225863735619, 0.31304086256223573]
[G Epoch 33/50] [loss: 0.134594]
epoch  2 gen:  [0.33260233918128684, 0.30570175438596481, 0.27061403508771892, 0.34399824171920695, 0.32616122460328545, 0.31289673917231825]
[G Epoch 34/50] [loss: -0.062441]
epoch  2 gen:  [0.33187134502924009, 0.30570175438596481, 0.27061403508771892, 0.34386349928948218, 0.32637638129586033, 0.31300855898639113]
[G Epoch 35/50] [loss: -0.678114]
epoch  2 gen:  [0.33187134502924009, 0.30570175438596481, 0.27083333333333298, 0.34348368171260302, 0.32616706683121482, 0.31303543164702607]
[G Epoch 36/50] [loss: -0.554803]
epoch  2 gen:  [0.33187134502924009, 0.30570175438596481, 0.27083333333333298, 0.343483

[G Epoch 9/50] [loss: -0.391191]
epoch  3 gen:  [0.33260233918128684, 0.30394736842105258, 0.27039473684210491, 0.34424331686636128, 0.32528520987824833, 0.31293476154731975]
[G Epoch 10/50] [loss: -0.773073]
epoch  3 gen:  [0.33260233918128684, 0.30394736842105258, 0.27039473684210491, 0.34437805929608606, 0.32539729839523229, 0.31302960137938007]
[G Epoch 11/50] [loss: -0.661010]
epoch  3 gen:  [0.33260233918128684, 0.30394736842105258, 0.27017543859649085, 0.34437805929608606, 0.32539729839523229, 0.31288401487110346]
[G Epoch 12/50] [loss: -0.262926]
epoch  3 gen:  [0.33333333333333365, 0.30394736842105258, 0.27039473684210491, 0.34489261930269, 0.32544885925662281, 0.31306561795693905]
[G Epoch 13/50] [loss: -0.805419]
epoch  3 gen:  [0.33333333333333365, 0.30394736842105258, 0.27017543859649085, 0.34489261930269, 0.32544885925662281, 0.31291892656721942]
[G Epoch 14/50] [loss: -0.603226]
epoch  3 gen:  [0.33333333333333365, 0.30394736842105258, 0.27017543859649085, 0.344892619302

[D Epoch 30/100] [loss: 0.661610]
[D Epoch 31/100] [loss: 0.661652]
[D Epoch 32/100] [loss: 0.661694]
[D Epoch 33/100] [loss: 0.661736]
[D Epoch 34/100] [loss: 0.661778]
[D Epoch 35/100] [loss: 0.661819]
[D Epoch 36/100] [loss: 0.661861]
[D Epoch 37/100] [loss: 0.661903]
[D Epoch 38/100] [loss: 0.661944]
[D Epoch 39/100] [loss: 0.661986]
[D Epoch 40/100] [loss: 0.662027]
[D Epoch 41/100] [loss: 0.662068]
[D Epoch 42/100] [loss: 0.662110]
[D Epoch 43/100] [loss: 0.662151]
[D Epoch 44/100] [loss: 0.662192]
[D Epoch 45/100] [loss: 0.662233]
[D Epoch 46/100] [loss: 0.662274]
[D Epoch 47/100] [loss: 0.662315]
[D Epoch 48/100] [loss: 0.662356]
[D Epoch 49/100] [loss: 0.662396]
[D Epoch 50/100] [loss: 0.662437]
[D Epoch 51/100] [loss: 0.662478]
[D Epoch 52/100] [loss: 0.662518]
[D Epoch 53/100] [loss: 0.662558]
[D Epoch 54/100] [loss: 0.662599]
[D Epoch 55/100] [loss: 0.662639]
[D Epoch 56/100] [loss: 0.662679]
[D Epoch 57/100] [loss: 0.662719]
[D Epoch 58/100] [loss: 0.662759]
[D Epoch 59/10

epoch  4 gen:  [0.33552631578947406, 0.30701754385964902, 0.27039473684210485, 0.34679781651016933, 0.32797917206044996, 0.31365934685306629]
[G Epoch 34/50] [loss: -0.369299]
epoch  4 gen:  [0.33552631578947406, 0.30657894736842095, 0.27039473684210485, 0.34717763408704849, 0.32806373056926103, 0.31392781754724192]
[G Epoch 35/50] [loss: -0.823994]
epoch  4 gen:  [0.33552631578947406, 0.30657894736842095, 0.27039473684210485, 0.34717763408704849, 0.32803113567379971, 0.31388288327674829]
[G Epoch 36/50] [loss: -0.910432]
epoch  4 gen:  [0.33552631578947406, 0.30657894736842095, 0.27039473684210485, 0.34717763408704849, 0.32803113567379971, 0.3138742507815927]
[G Epoch 37/50] [loss: -0.847303]
epoch  4 gen:  [0.33552631578947406, 0.30614035087719282, 0.27061403508771892, 0.34704289165732372, 0.3276134284107724, 0.31390345063206165]
[G Epoch 38/50] [loss: -1.003228]
epoch  4 gen:  [0.33552631578947406, 0.30570175438596475, 0.27039473684210485, 0.34704289165732372, 0.32732569797763134, 0

[G Epoch 11/50] [loss: -0.707366]
epoch  5 gen:  [0.33479532163742726, 0.30657894736842095, 0.27039473684210485, 0.34652833165071989, 0.32784959798252294, 0.31377969369343756]
[G Epoch 12/50] [loss: -0.903466]
epoch  5 gen:  [0.33479532163742726, 0.30614035087719293, 0.27061403508771892, 0.34666307408044461, 0.32769184437926802, 0.31401027944559701]
[G Epoch 13/50] [loss: -0.458240]
epoch  5 gen:  [0.33479532163742726, 0.30614035087719293, 0.27061403508771892, 0.34666307408044461, 0.32762665458834539, 0.31396613579856991]
[G Epoch 14/50] [loss: -1.021748]
epoch  5 gen:  [0.33479532163742726, 0.30614035087719293, 0.27061403508771892, 0.34666307408044461, 0.32769184437926802, 0.31401435257804639]
[G Epoch 15/50] [loss: -1.131844]
epoch  5 gen:  [0.33479532163742726, 0.30614035087719293, 0.27061403508771892, 0.34666307408044461, 0.32769184437926802, 0.31401435257804639]
[G Epoch 16/50] [loss: -1.074906]
epoch  5 gen:  [0.33479532163742726, 0.30570175438596481, 0.27061403508771892, 0.34666

[D Epoch 40/100] [loss: 0.668771]
[D Epoch 41/100] [loss: 0.668797]
[D Epoch 42/100] [loss: 0.668823]
[D Epoch 43/100] [loss: 0.668849]
[D Epoch 44/100] [loss: 0.668875]
[D Epoch 45/100] [loss: 0.668901]
[D Epoch 46/100] [loss: 0.668927]
[D Epoch 47/100] [loss: 0.668953]
[D Epoch 48/100] [loss: 0.668979]
[D Epoch 49/100] [loss: 0.669004]
[D Epoch 50/100] [loss: 0.669030]
[D Epoch 51/100] [loss: 0.669056]
[D Epoch 52/100] [loss: 0.669081]
[D Epoch 53/100] [loss: 0.669107]
[D Epoch 54/100] [loss: 0.669132]
[D Epoch 55/100] [loss: 0.669157]
[D Epoch 56/100] [loss: 0.669183]
[D Epoch 57/100] [loss: 0.669208]
[D Epoch 58/100] [loss: 0.669233]
[D Epoch 59/100] [loss: 0.669258]
[D Epoch 60/100] [loss: 0.669283]
[D Epoch 61/100] [loss: 0.669308]
[D Epoch 62/100] [loss: 0.669333]
[D Epoch 63/100] [loss: 0.669358]
[D Epoch 64/100] [loss: 0.669383]
[D Epoch 65/100] [loss: 0.669407]
[D Epoch 66/100] [loss: 0.669432]
[D Epoch 67/100] [loss: 0.669457]
[D Epoch 68/100] [loss: 0.669481]
[D Epoch 69/10

epoch  6 gen:  [0.33479532163742726, 0.30526315789473668, 0.27039473684210491, 0.34625884679127034, 0.3269220223960953, 0.31368537768238031]
[G Epoch 36/50] [loss: -0.543295]
epoch  6 gen:  [0.33479532163742726, 0.30526315789473668, 0.27061403508771892, 0.34639358922099511, 0.32701940433052012, 0.31387369174708518]
[G Epoch 37/50] [loss: -0.784452]
epoch  6 gen:  [0.33479532163742726, 0.30614035087719282, 0.27061403508771892, 0.34652833165071983, 0.32769224713122713, 0.31397938660092595]
[G Epoch 38/50] [loss: -1.222783]
epoch  6 gen:  [0.33479532163742726, 0.30614035087719282, 0.27061403508771892, 0.34652833165071983, 0.32769224713122713, 0.31398516167824048]
[G Epoch 39/50] [loss: -1.068835]
epoch  6 gen:  [0.33479532163742726, 0.30614035087719293, 0.27039473684210485, 0.34639358922099511, 0.32759486519680231, 0.31378244819375922]
[G Epoch 40/50] [loss: -1.099664]
epoch  6 gen:  [0.33479532163742726, 0.30657894736842095, 0.27083333333333293, 0.34639358922099511, 0.32788259562994337, 

[G Epoch 13/50] [loss: -0.774715]
epoch  7 gen:  [0.33625730994152087, 0.30570175438596481, 0.27039473684210491, 0.34728796680447815, 0.32728027965655626, 0.31373723224749228]
[G Epoch 14/50] [loss: -1.069206]
epoch  7 gen:  [0.33625730994152087, 0.30614035087719293, 0.27061403508771892, 0.34728796680447815, 0.32753541519423601, 0.3138845706915469]
[G Epoch 15/50] [loss: -1.505967]
epoch  7 gen:  [0.33552631578947406, 0.30657894736842095, 0.27039473684210485, 0.34677340679787422, 0.32777158476598656, 0.31372741821205563]
[G Epoch 16/50] [loss: -1.229196]
epoch  7 gen:  [0.33552631578947406, 0.30657894736842095, 0.27061403508771892, 0.34677340679787422, 0.32773898987052524, 0.31384412838703557]
[G Epoch 17/50] [loss: -1.216997]
epoch  7 gen:  [0.33625730994152087, 0.30657894736842095, 0.27061403508771892, 0.34728796680447815, 0.32779055073191582, 0.31390587677735127]
[G Epoch 18/50] [loss: -1.073790]
epoch  7 gen:  [0.33552631578947406, 0.30657894736842095, 0.27039473684210491, 0.346773

[D Epoch 51/100] [loss: 0.673030]
[D Epoch 52/100] [loss: 0.673045]
[D Epoch 53/100] [loss: 0.673060]
[D Epoch 54/100] [loss: 0.673075]
[D Epoch 55/100] [loss: 0.673089]
[D Epoch 56/100] [loss: 0.673104]
[D Epoch 57/100] [loss: 0.673119]
[D Epoch 58/100] [loss: 0.673133]
[D Epoch 59/100] [loss: 0.673148]
[D Epoch 60/100] [loss: 0.673162]
[D Epoch 61/100] [loss: 0.673177]
[D Epoch 62/100] [loss: 0.673191]
[D Epoch 63/100] [loss: 0.673206]
[D Epoch 64/100] [loss: 0.673220]
[D Epoch 65/100] [loss: 0.673234]
[D Epoch 66/100] [loss: 0.673248]
[D Epoch 67/100] [loss: 0.673263]
[D Epoch 68/100] [loss: 0.673276]
[D Epoch 69/100] [loss: 0.673291]
[D Epoch 70/100] [loss: 0.673305]
[D Epoch 71/100] [loss: 0.673318]
[D Epoch 72/100] [loss: 0.673332]
[D Epoch 73/100] [loss: 0.673346]
[D Epoch 74/100] [loss: 0.673360]
[D Epoch 75/100] [loss: 0.673374]
[D Epoch 76/100] [loss: 0.673388]
[D Epoch 77/100] [loss: 0.673401]
[D Epoch 78/100] [loss: 0.673415]
[D Epoch 79/100] [loss: 0.673429]
[D Epoch 80/10

[G Epoch 38/50] [loss: -1.114574]
epoch  8 gen:  [0.33771929824561442, 0.30570175438596475, 0.271271929824561, 0.34849924704176022, 0.32741010092854733, 0.31444583563599809]
[G Epoch 39/50] [loss: -0.779651]
epoch  8 gen:  [0.33771929824561442, 0.30570175438596475, 0.271271929824561, 0.34836450461203544, 0.32724752920319988, 0.31432027032910032]
[G Epoch 40/50] [loss: -0.907017]
epoch  8 gen:  [0.33771929824561442, 0.30570175438596475, 0.271271929824561, 0.34836450461203544, 0.32724752920319988, 0.31432889467156211]
[G Epoch 41/50] [loss: -1.351862]
epoch  8 gen:  [0.33771929824561442, 0.30570175438596475, 0.27149122807017506, 0.34849924704176022, 0.32741010092854733, 0.31459397925082599]
[G Epoch 42/50] [loss: -1.402904]
epoch  8 gen:  [0.33771929824561442, 0.30570175438596475, 0.27149122807017506, 0.34863398947148494, 0.32750748286297215, 0.31468273684278902]
[G Epoch 43/50] [loss: -0.829934]
epoch  8 gen:  [0.33771929824561442, 0.30570175438596475, 0.27149122807017506, 0.34863398947

[G Epoch 16/50] [loss: -1.129186]
epoch  9 gen:  [0.33698830409356767, 0.30526315789473668, 0.27105263157894693, 0.34836450461203544, 0.32735812750496462, 0.31454476777155582]
[G Epoch 17/50] [loss: -1.196146]
epoch  9 gen:  [0.33771929824561442, 0.30526315789473668, 0.27105263157894693, 0.348633989471485, 0.32719997114975052, 0.31444565077132475]
[G Epoch 18/50] [loss: -1.194298]
epoch  9 gen:  [0.33698830409356767, 0.30526315789473668, 0.27105263157894693, 0.34849924704176016, 0.32745550943938945, 0.31461367220567898]
[G Epoch 19/50] [loss: -0.965334]
epoch  9 gen:  [0.33698830409356767, 0.30526315789473668, 0.271271929824561, 0.34836450461203544, 0.32726034281858069, 0.31463693711171459]
[G Epoch 20/50] [loss: -1.042768]
epoch  9 gen:  [0.33698830409356767, 0.30526315789473668, 0.27105263157894693, 0.34901380704836404, 0.32776220583845977, 0.31481757835248619]
[G Epoch 21/50] [loss: -0.756589]
epoch  9 gen:  [0.33771929824561442, 0.30526315789473668, 0.27149122807017501, 0.349393624

[D Epoch 67/100] [loss: 0.675289]
[D Epoch 68/100] [loss: 0.675295]
[D Epoch 69/100] [loss: 0.675300]
[D Epoch 70/100] [loss: 0.675305]
[D Epoch 71/100] [loss: 0.675310]
[D Epoch 72/100] [loss: 0.675315]
[D Epoch 73/100] [loss: 0.675320]
[D Epoch 74/100] [loss: 0.675325]
[D Epoch 75/100] [loss: 0.675330]
[D Epoch 76/100] [loss: 0.675335]
[D Epoch 77/100] [loss: 0.675340]
[D Epoch 78/100] [loss: 0.675345]
[D Epoch 79/100] [loss: 0.675349]
[D Epoch 80/100] [loss: 0.675354]
[D Epoch 81/100] [loss: 0.675359]
[D Epoch 82/100] [loss: 0.675363]
[D Epoch 83/100] [loss: 0.675368]
[D Epoch 84/100] [loss: 0.675372]
[D Epoch 85/100] [loss: 0.675376]
[D Epoch 86/100] [loss: 0.675381]
[D Epoch 87/100] [loss: 0.675385]
[D Epoch 88/100] [loss: 0.675389]
[D Epoch 89/100] [loss: 0.675393]
[D Epoch 90/100] [loss: 0.675397]
[D Epoch 91/100] [loss: 0.675401]
[D Epoch 92/100] [loss: 0.675405]
[D Epoch 93/100] [loss: 0.675409]
[D Epoch 94/100] [loss: 0.675412]
[D Epoch 95/100] [loss: 0.675416]
[D Epoch 96/10

[G Epoch 41/50] [loss: -0.994186]
epoch  10 gen:  [0.33625730994152087, 0.30394736842105263, 0.27149122807017501, 0.34771520217570684, 0.32628080361880318, 0.31474386803041859]
[G Epoch 42/50] [loss: -0.882785]
epoch  10 gen:  [0.33625730994152087, 0.30394736842105263, 0.27149122807017501, 0.34758045974598212, 0.32618342168437836, 0.31468067381830339]
[G Epoch 43/50] [loss: -1.479667]
epoch  10 gen:  [0.33625730994152087, 0.30394736842105263, 0.27105263157894693, 0.34771520217570684, 0.3263133985142645, 0.31449979536255551]
[G Epoch 44/50] [loss: -0.760995]
epoch  10 gen:  [0.33625730994152087, 0.30438596491228065, 0.271271929824561, 0.34809501975258594, 0.3268756332029738, 0.31483809100554655]
[G Epoch 45/50] [loss: -1.153921]
epoch  10 gen:  [0.33698830409356767, 0.30438596491228065, 0.271271929824561, 0.34860957975918982, 0.32692719406436432, 0.31489195925059871]
[G Epoch 46/50] [loss: -1.538365]
epoch  10 gen:  [0.33625730994152087, 0.30438596491228065, 0.271271929824561, 0.3480950

In [116]:
generator.user_embeddings.weight = torch.nn.Parameter(torch.tensor(param[0]))
generator.item_embeddings.weight = torch.nn.Parameter(torch.tensor(param[1]))

In [108]:
generator

Generator(
  (user_embeddings): Embedding(943, 5)
  (item_embeddings): Embedding(1683, 5)
  (item_bias): Embedding(1683, 1)
)