In [None]:
import os

# Expose only GPU index 1 to this process; the variable only takes effect if it is
# set before CUDA is initialised, hence it lives in the very first cell.
os.environ.update({'CUDA_VISIBLE_DEVICES': '1'})

In [None]:
import pickle

# Load per-video frame features and ground-truth summaries.
# NOTE: pickle.load can execute arbitrary code — only load files from a trusted source.
# `with` closes the file handles, which the bare open(...) calls leaked.
with open('Dataset/dataset_ytsum/feature.pkl', 'rb') as fh:
    feature = pickle.load(fh)
with open('Dataset/dataset_ytsum/summary.pkl', 'rb') as fh:
    summary = pickle.load(fh)

print(len(summary))
print(len(feature), list(feature.keys()))

In [None]:
import numpy as np

import torch as T
import torch.nn as nn

from torch.utils.data import Dataset, DataLoader

class DS_YTSum(Dataset):
    """Padded encoder/decoder arrays for summarisation, one item per key of ``summary``.

    ``feature[k]`` is indexed as a sequence of 4096-d frame vectors; ``feature['START']``
    and ``feature['END']`` are the special START / END vectors. ``summary[k]`` is a list
    of selected frame indices that this code treats as 1-based (it subtracts 1).
    Every item is padded to the longest video / summary in ``summary`` plus 4 END slots.

    ``__getitem__`` returns ``(enc, dec, ans, l_in, l_ot, wgt)``:
      enc  (MXL_ENC, 4096) float32: the frames, then END rows (first END at l_in-1).
      dec  (MXL_DEC, 4096) float32: START, the selected frames, then END rows.
      ans  (MXL_DEC,) int64: 0-based target frame index per decoder step; every step
           from l_ot-1 on points at the first END row of ``enc``.
      l_in int: len(feature[k]) + 1.
      l_ot int: len(summary[k]) + 1.
      wgt  (MXL_DEC,) float32: 1 for the l_ot real targets (incl. the END step), else 0.
    """
    def __init__(self, feature, summary):
        super(DS_YTSum, self).__init__()

        self.feature = feature
        self.summary = summary
        self.key = list(self.summary.keys())

        # Longest video / summary over the dataset, plus 4 spare END slots.
        MXL_ENC = max((len(self.feature[k]) for k in self.key), default=0) + 4
        MXL_DEC = max((len(self.summary[k]) for k in self.key), default=0) + 4

        self.enc, self.dec, self.ans = dict(), dict(), dict()
        self.l_in, self.l_ot, self.wgt = dict(), dict(), dict()

        for k in self.key:
            frames = self.feature[k]
            picks = self.summary[k]
            n_in, n_ot = len(frames), len(picks)

            enc = np.zeros((MXL_ENC, 4096), dtype=np.float32)
            dec = np.zeros((MXL_DEC, 4096), dtype=np.float32)
            # np.int was removed in NumPy 1.24; int64 is also the dtype CrossEntropyLoss
            # expects for class targets.
            ans = np.zeros((MXL_DEC, ), dtype=np.int64)
            wgt = np.zeros((MXL_DEC, ), dtype=np.float32)

            self.l_in[k], self.l_ot[k] = n_in + 1, n_ot + 1

            # Encoder: real frames followed by END padding.
            for i in range(n_in):
                enc[i] = frames[i]
            enc[n_in:] = self.feature['END']

            # Decoder input: START, then the selected frames, then END padding.
            dec[0] = self.feature['START']
            for i, p in enumerate(picks, start=1):
                dec[i] = frames[p - 1]
            dec[n_ot + 1:] = self.feature['END']

            # Targets: 0-based frame index per step; afterwards point at the first END row.
            for i, p in enumerate(picks):
                ans[i] = p - 1
            ans[n_ot:] = n_in
            # Real targets plus the single terminating END step carry weight 1.
            wgt[:n_ot + 1] = 1

            self.enc[k] = enc
            self.dec[k] = dec
            self.ans[k] = ans
            self.wgt[k] = wgt

    def __len__(self):
        return len(self.key)

    def __getitem__(self, idx):
        k = self.key[idx]

        return self.enc[k], self.dec[k], self.ans[k], self.l_in[k], self.l_ot[k], self.wgt[k]

ld = DataLoader(DS_YTSum(feature, summary), batch_size=8, shuffle=True)

# Pull a single shuffled batch to sanity-check shapes and the START / END layout.
inp_enc, inp_dec, ans, inp_lin, inp_lot, wgt = next(iter(ld))

print(inp_enc.shape, inp_dec.shape)
print(ans.shape)
print(inp_lin.shape, inp_lot.shape)
print(wgt.shape)

print()
print(inp_dec[0][0])                # decoder step 0 is the START vector
print(inp_enc[0][inp_lin[0]-1])     # first encoder padding row is the END vector

print()
print(ans[0])
print(inp_lin[0], inp_lot[0])

print()
print(wgt[0])

In [None]:
import math

class Model_G(nn.Module):
    """GRU encoder-decoder generator that scores encoder positions at every decoder step.

    The encoder is a 2-layer bidirectional GRU over 4096-d inputs; the decoder is a
    1-layer GRU of hidden size ``2*hid`` initialised from the encoder's last-layer
    forward/backward hidden states. Decoder-step logits over encoder positions are
    ``dec_out @ att @ enc_out^T``.

    Args:
        hid: hidden size of each encoder direction (the decoder uses ``2*hid``).
        dp: dropout between the two encoder GRU layers.
        start: optional 4096-d START vector fed as the first decoder input during
            rollouts. Defaults to the notebook-global ``feature['START']`` as before.
    """
    def __init__(self, hid=256, dp=0.2, start=None):
        super(Model_G, self).__init__()

        self.hid = hid
        self.dp = dp

        self.gru_enc = nn.GRU(4096, hid,
                              num_layers=2, batch_first=True, dropout=dp, bidirectional=True)
        self.gru_dec = nn.GRU(4096, hid*2,
                              batch_first=True)
        self.att = nn.Parameter(T.FloatTensor(hid*2, hid*2))

        self.init()

        if start is None:
            start = feature['START']  # notebook-global loaded from feature.pkl
        # Plain attribute (not a registered buffer, so state_dict/checkpoints are
        # unchanged); it is moved to the input's device in rollout() instead of being
        # hard-wired to CUDA here.
        self.START = T.tensor(np.asarray(start), dtype=T.float32).view((1, 1, 4096))

    def init(self):
        """Initialise ``att`` uniformly in ``[-1/sqrt(2*hid), 1/sqrt(2*hid))``."""
        stdv = 1/math.sqrt(self.hid*2)

        # uniform_(from, to) requires from <= to; the bounds were previously reversed.
        self.att.data.uniform_(-stdv, stdv)

    def rollout(self, inp_enc, out_enc, h, M=1, policy='greedy', max_len=22):
        """Autoregressively decode ``M`` position sequences of ``max_len`` steps.

        Intended for a single sample (``forward`` passes ``[i:i+1]`` slices); only batch
        element 0 of the softmax is used to pick the next position.

        Args:
            inp_enc: encoder inputs; ``inp_enc[:, p:p+1, :]`` becomes the next decoder input.
            out_enc: encoder outputs for the same sample.
            h: initial decoder hidden state (every rollout restarts from it).
            M: number of rollouts.
            policy: ``'greedy'`` takes the argmax; any other value samples from the softmax.
            max_len: number of decoding steps.

        Returns:
            ``(out, prob, pos)``: logits and log-softmax, each concatenated as
            ``(M, max_len, L_enc)`` for one sample, and a list of ``M`` position lists.
        """
        start = self.START.to(inp_enc.device)

        outs, probs, poss = [], [], []
        for _ in range(M):
            step_inp, h_dec = start, h
            out, prob, pos = [], [], []

            for _ in range(max_len):
                step_out, h_dec = self.gru_dec(step_inp, h_dec)
                logit = T.bmm(T.matmul(step_out, self.att), out_enc.transpose(1, 2))
                out.append(logit)

                log_p = nn.functional.log_softmax(logit, dim=2)
                prob.append(log_p)

                pb = T.exp(log_p).detach().cpu().numpy()[0][0]
                if policy == 'greedy':
                    p = np.argmax(pb)
                else:
                    p = np.random.choice(len(pb), p=pb)
                pos.append(p)

                # Feed the encoder feature at the chosen position into the next step.
                step_inp = inp_enc[:, p:p+1, :]

            outs.append(T.cat(out, dim=1))
            probs.append(T.cat(prob, dim=1))
            poss.append(pos)

        return T.cat(outs, dim=0), T.cat(probs, dim=0), poss

    def forward(self, inp_enc,
                inp_dec=None, is_tr=True, is_pg=False, M=8, max_len=22):
        """Run the generator sample by sample.

        Modes:
            is_tr and not is_pg: teacher forcing on ``inp_dec``; returns ``(logits, None, [])``.
            is_tr and is_pg: ``M`` sampled rollouts per sample (rows ordered sample-major).
            not is_tr: one greedy rollout per sample.

        Raises:
            ValueError: teacher forcing requested without ``inp_dec``.
        """
        if is_tr and not is_pg and inp_dec is None:
            raise ValueError('inp_dec is required for teacher forcing (is_tr=True, is_pg=False)')

        outs, probs, poss = [], [], []
        for i in range(inp_enc.shape[0]):
            enc_i = inp_enc[i:i+1]
            out_enc, h = self.gru_enc(enc_i)
            # Last encoder layer's forward + backward states -> (1, 1, 2*hid) decoder state.
            h = h[-2:].view((1, 1, self.hid*2))

            if is_tr and not is_pg:
                out_dec, _ = self.gru_dec(inp_dec[i:i+1], h)
                outs.append(T.bmm(T.matmul(out_dec, self.att), out_enc.transpose(1, 2)))
                continue

            if is_tr:
                out, prob, pos = self.rollout(enc_i, out_enc, h, M=M, policy='sample', max_len=max_len)
            else:
                out, prob, pos = self.rollout(enc_i, out_enc, h, M=1, policy='greedy', max_len=max_len)

            outs.append(out)
            probs.append(prob)
            poss.append(pos)

        out = T.cat(outs, dim=0)
        prob = T.cat(probs, dim=0) if probs else None
        pos = [p for pp in poss for p in pp]

        return out, prob, pos

model_g = Model_G().cuda()
print(model_g)

# Teacher forcing: the decoder consumes the ground-truth summary features.
out, _, _ = model_g(inp_enc.cuda(), inp_dec.cuda(), is_tr=True, is_pg=False)
print(out.shape)

# Greedy inference (one rollout per video), then sampled rollouts (default M=8 per video).
for rollout_kwargs in (dict(is_tr=False), dict(is_tr=True, is_pg=True)):
    out, prob, pos = model_g(inp_enc.cuda(), **rollout_kwargs)
    print(out.shape, prob.shape, len(pos), len(pos[0]))
del rollout_kwargs

In [None]:
from tqdm import tqdm_notebook as tqdm

In [None]:
EPOCHS = 120

loss_ce = nn.CrossEntropyLoss(reduction='none').cuda()
optim = T.optim.Adam(model_g.parameters(), lr=0.00008)

# Teacher-forced pre-training of the generator with a per-step weighted cross-entropy.
# NOTE(review): the trained weights are not saved in this cell, and the next cell loads
# Model/only_g.pt over them — save here if this run should be kept.
for e in tqdm(range(EPOCHS)):
    ls_ep = 0

    model_g.train()
    with tqdm(ld) as TQ:
        for inp_enc, inp_dec, ans, inp_lin, inp_lot, wgt in TQ:
            out, _, _ = model_g(inp_enc.cuda(), inp_dec.cuda(), is_tr=True, is_pg=False)

            # out is (batch, dec_len, enc_len); flatten with the real enc_len instead of
            # the previously hard-coded 578 so other padding lengths/datasets work.
            out = out.view((-1, out.shape[-1]))
            ans = ans.view((-1, )).cuda()
            wgt = wgt.view((-1, )).cuda()

            # Zero-weight (padding) steps are ignored; average over weighted steps only.
            ls_bh = loss_ce(out, ans)*wgt
            ls_bh = ls_bh.sum()/wgt.sum()

            optim.zero_grad()
            ls_bh.backward()
            optim.step()

            ls_bh = ls_bh.item()
            ls_ep += ls_bh

            TQ.set_postfix(ls_bh='%.3f'%(ls_bh))

        ls_ep /= len(TQ)
        print('Ep %d: %.4f' % (e+1, ls_ep))

In [None]:
model_g.load_state_dict(T.load('Model/only_g.pt'))

model_g.eval()
# Greedy decoding needs only the encoder input (inp_dec is ignored when is_tr=False);
# no_grad avoids building autograd graphs for every rollout step.
with T.no_grad():
    for inp_enc, inp_dec, ans, inp_lin, inp_lot, wgt in ld:
        _, _, pos = model_g(inp_enc.cuda(), is_tr=False)

        ans = ans.numpy()

        # First video of each batch: predicted positions vs. target positions.
        print(pos[0])
        print(ans[0])
        print('----------')

In [None]:
class Model_D(nn.Module):
    """1-D CNN discriminator giving one sigmoid score per feature sequence.

    Input ``(batch, length, 4096)`` -> output ``(batch, 1)`` in (0, 1). The two unpadded
    kernel-3 convolutions shorten the sequence by 4, so ``length`` must be >= 5.
    """
    def __init__(self):
        super(Model_D, self).__init__()

        self.cnn = nn.Sequential(*[nn.Conv1d(4096, 256, 3, padding=1), nn.ReLU(),
                                   nn.Conv1d(256, 128, 3, padding=1), nn.ReLU(),
                                   nn.Conv1d(128, 128, 3), nn.ReLU(),
                                   nn.Conv1d(128, 64, 3), nn.ReLU()])
        self.fc = nn.Sequential(*[nn.Linear(64, 16), nn.ReLU(),
                                  nn.Linear(16, 16), nn.ReLU(),
                                  nn.Linear(16, 1), nn.Sigmoid()])

    def forward(self, inp):
        inp = inp.transpose(1, 2)  # Conv1d expects (batch, channels, length)
        feat = self.cnn(inp)       # (batch, 64, length-4)
        # Average the remaining positions. For length 5 only one position is left, so this
        # equals the previous flatten exactly; for longer inputs the flatten produced
        # 64*(length-4) features and crashed in Linear(64, 16).
        out = self.fc(feat.mean(dim=2))

        return out

model_d = Model_D().cuda()
print(model_d)

# Smoke test: 8 random 5-step sequences in, one score per sequence out.
inp = T.rand(8, 5, 4096).cuda()
out = model_d(inp)
print(out.shape)

In [None]:
class Loss_PG(nn.Module):
    """REINFORCE loss with a running-average reward baseline.

    Each call returns the mean of ``-log p(pos[i][j]) * (rwd[i] - baseline)`` over the
    first ``lot[i]`` steps of every rollout ``i``, using the baseline from before the
    call, and then folds ``mean(rwd)`` into the baseline (cumulative mean over calls).
    """
    def __init__(self):
        super(Loss_PG, self).__init__()

        self.bl = 0  # running baseline: cumulative mean of per-call average rewards
        self.bn = 0  # number of calls folded into the baseline

    def forward(self, prob, pos, lot, rwd):
        """
        Args:
            prob: ``(batch, steps, n_pos)`` log-probabilities.
            pos: chosen positions, convertible to a ``(batch, >= steps)`` int array.
            lot: per-rollout number of steps to score; steps beyond ``steps`` are ignored.
            rwd: per-rollout reward (array-like; a 0-d value is treated as batch 1).

        Returns:
            0-d loss tensor (0 when no step is scored).
        """
        steps = prob.shape[1]
        dev = prob.device

        rwd = np.atleast_1d(np.asarray(rwd, dtype=np.float64))
        lot_t = T.as_tensor(np.asarray(lot, dtype=np.int64).reshape(-1), device=dev)

        # log p of the chosen position at each step, gathered at once: (batch, steps).
        chosen = T.as_tensor(np.asarray(pos, dtype=np.int64)[:, :steps], device=dev)
        logp = prob.gather(2, chosen.unsqueeze(2)).squeeze(2)

        # Only the first lot[i] steps of rollout i contribute.
        mask = (T.arange(steps, device=dev).unsqueeze(0) < lot_t.unsqueeze(1)).to(prob.dtype)
        adv = T.as_tensor(rwd - self.bl, dtype=prob.dtype, device=dev).unsqueeze(1)

        ls = -(logp*adv*mask).sum()/mask.sum().clamp(min=1)

        self.bl = (self.bl*self.bn + float(np.average(rwd)))/(self.bn+1)
        self.bn += 1

        return ls

# Adversarial fine-tuning: epochs, losses and one Adam optimiser per network.
EPOCHS = 40

loss_mse = nn.MSELoss().cuda()
loss_pg = Loss_PG().cuda()

optim_g, optim_d = (T.optim.Adam(net.parameters(), lr=0.0001) for net in (model_g, model_d))

In [None]:
def build_cube(inp, pos, lot, n=5, device='cuda'):
    """Sample ``n`` frames of one video, in index order, as a discriminator input.

    Args:
        inp: per-frame features of one video, indexed as ``inp[p]``.
        pos: candidate frame indices; only the first ``lot - 1`` are used.
        lot: decoder length (0-d tensor, NumPy scalar or int); ``lot - 1`` real positions.
        n: frames per cube (5, matching the previous hard-coded value).
        device: device the cube is moved to (``'cuda'`` as before).

    Returns:
        ``(1, n, D)`` tensor of the chosen rows, sorted by index. Sampling is without
        replacement unless there are fewer than ``n`` candidates.

    Raises:
        ValueError: if there are no candidate positions.
    """
    n_cand = int(lot) - 1
    cand = np.asarray(pos)[:max(n_cand, 0)]
    if len(cand) == 0:
        raise ValueError('build_cube: no candidate positions (lot - 1 = %d)' % n_cand)

    picked = np.sort(np.random.choice(cand, n, replace=len(cand) < n))
    cube = inp[T.as_tensor(picked, dtype=T.long)].reshape((1, n, -1))

    return cube.to(device)

N_ROLL = 4  # sampled generator rollouts per video in each policy-gradient step

for e in tqdm(range(EPOCHS)):
    ls_g_ep = 0
    ls_d_ep = 0

    model_g.train()
    with tqdm(ld) as TQ:
        for inp_enc, inp_dec, ans, inp_lin, inp_lot, wgt in TQ:
            # Sampled rollouts; inp_dec is not used in policy-gradient mode.
            out, prob, pos = model_g(inp_enc.cuda(), is_tr=True, is_pg=True, M=N_ROLL)

            # update G: reward = 1 - D(fake cube). Rollouts are ordered video-major,
            # so rollout i belongs to video i // N_ROLL.
            cube = T.cat([build_cube(inp_enc[i//N_ROLL], pos[i], inp_lot[i//N_ROLL])
                          for i in range(len(pos))], dim=0)
            rwd = 1-model_d(cube).detach().cpu().numpy().squeeze()
            ls_bh = loss_pg(prob, pos, np.repeat(inp_lot.numpy(), N_ROLL), rwd)

            optim_g.zero_grad()
            ls_bh.backward()
            optim_g.step()

            ls_g_ep += ls_bh.item()

            # update D (false-case). Target has D's (n, 1) shape: an (n,) target made
            # MSELoss broadcast to (n, n) and emit a size-mismatch warning.
            out = model_d(cube)
            ls_bh = loss_mse(out, T.zeros_like(out))

            optim_d.zero_grad()
            ls_bh.backward()
            optim_d.step()

            ls_d_ep += ls_bh.item()

            # update D (true-case): cubes sampled from the ground-truth target positions.
            ans_np = ans.numpy()
            cube = T.cat([build_cube(inp_enc[i], ans_np[i], inp_lot[i])
                          for i in range(inp_enc.shape[0])], dim=0)

            out = model_d(cube)
            ls_bh = loss_mse(out, T.ones_like(out))

            optim_d.zero_grad()
            ls_bh.backward()
            optim_d.step()

            ls_d_ep += ls_bh.item()

        ls_g_ep /= len(TQ)
        ls_d_ep /= len(TQ)*2
        print('Ep %d: G-%.3f D-%.3f' % (e+1, ls_g_ep, ls_d_ep))

In [None]:
# Restore checkpoints: the discriminator from Model/only_d.pt and the generator from
# Model/g+d.pt (presumably the generator after adversarial training — confirm).
# NOTE(review): running this right after the adversarial-training cell replaces the
# weights just trained; no cell before this one saves them.
model_d.load_state_dict(T.load('Model/only_d.pt'))
model_g.load_state_dict(T.load('Model/g+d.pt'))

In [None]:
# Persist the discriminator weights, overwriting Model/only_d.pt.
# NOTE(review): no visible cell saves the generator, although Model/only_g.pt and
# Model/g+d.pt are loaded earlier — confirm how those checkpoints are produced.
T.save(model_d.state_dict(), 'Model/only_d.pt')