In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import numpy as np
import pandas as pd
import warnings; warnings.filterwarnings('ignore')
from keras.preprocessing.text import Tokenizer
from keras.preprocessing import sequence
import random

Using TensorFlow backend.


In [2]:
from data_loader import MRLoader
batch_size = 1

loader = MRLoader(batch_size)
train_loader, test_loader = loader.get_dataset()
tokenizer = loader.tokenizer

In [3]:
from data_loader import Glove
glove = Glove(300)
vocab_size, embedding_dim = glove.vocab_size, glove.embedding_dim
embedding_matrix = glove.get_embedding(tokenizer)

In [4]:
class CNet(nn.Module):
    def __init__(self):
        super(CNet, self).__init__()
        
        self.input_layer = nn.Linear(300, 128)
        self.output_layer = nn.Linear(128, 1)
        nn.init.xavier_uniform_(self.input_layer.weight)
        nn.init.xavier_uniform_(self.output_layer.weight)
#         self.dropout = nn.Dropout(0.5)

    def forward(self, x):
#         x = self.dropout(x)
        x = F.relu(self.input_layer(x))
        x = F.sigmoid(self.output_layer(x))
        return x

In [5]:
class REINFORCE(nn.Module):
    def __init__(self, state_size, action_size):
        super(REINFORCE, self).__init__()

        self.state_size = state_size
        self.action_size = action_size
        
        self.input_layer = nn.Linear(self.state_size, 256)
        self.hidden_layer = nn.Linear(256, 256)
        self.output_layer = nn.Linear(256, self.action_size)
        nn.init.xavier_uniform_(self.input_layer.weight)
        nn.init.xavier_uniform_(self.hidden_layer.weight)
        nn.init.xavier_uniform_(self.output_layer.weight)
        
    def forward(self, state):
        s0, s1, s2 = state.shape
        x = F.relu(self.input_layer(state.reshape(s1*s2,)))
        x = F.relu(self.hidden_layer(x))
        x = F.softmax(self.output_layer(x))
        return x

In [6]:
class REINFORCEAgent:
    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size
        
        self.discount_factor = 0.99
        self.learning_rate = 0.01
        self.states, self.actions, self.rewards, self.log_probs = [], [], [], []
        self.pre_trained = False
        
        self.model = REINFORCE(state_size, action_size).cuda()
        
        self.model_loss = torch.nn.MSELoss()
        self.model_optim = torch.optim.Adam(
            self.model.parameters(), self.learning_rate
        )

    def use_pretrained(self, filename):
        self.model.load_state_dict(torch.load(filename))

    def get_action(self, state):
        policy = self.model(state)
        action = np.random.choice(self.action_size, 1, p=policy.clone().cpu().detach().numpy())[0]
        log_prob = torch.log(policy.squeeze(0)[action])
        return action, log_prob

    def append_sample(self, state, action, reward, log_prob):
        self.states.append(state)
        act = torch.FloatTensor(self.action_size).fill_(0)
        act[action] = 1
        self.actions.append(act)
        self.rewards.append(reward)
        self.log_probs.append(log_prob)
        
    def discount_rewards(self, rewards):
        discounted_rewards = torch.FloatTensor(len(rewards)).fill_(0)
        running_add = 0
        for t in reversed(range(0, len(rewards))):
            running_add = running_add * self.discount_factor + rewards[t]
            discounted_rewards[t] = running_add
        return discounted_rewards

    def train_model(self):   
        discounted_rewards = self.discount_rewards(self.rewards)
        std = discounted_rewards.std()
        std = 1 if std.item() == 0 else std
        
        discounted_rewards -= discounted_rewards.mean()
        discounted_rewards /= std
#         print(self.log_probs, discounted_rewards)
        policy_gradient = [-prob * G for prob, G in zip(self.log_probs, discounted_rewards)]
        
        
        self.model_optim.zero_grad()
#         print(policy_gradient)
        loss = torch.stack(policy_gradient).sum()
        loss.backward()
        self.model_optim.step()
        
        self.states, self.actions, self.rewards, self.log_probs = [], [], [], []


In [7]:
class RL_CNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, fixed_length=300, kernel_num=100, kernel_size=[3, 4, 5]):
        super(RL_CNN, self).__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
                        
        self.embedding = nn.Embedding(self.vocab_size, self.embedding_dim)
        self.conv0 = nn.Conv2d(1, kernel_num, (kernel_size[0], self.embedding_dim))
        self.conv1 = nn.Conv2d(1, kernel_num, (kernel_size[1], self.embedding_dim))
        self.conv2 = nn.Conv2d(1, kernel_num, (kernel_size[2], self.embedding_dim))

        self.maxpools = [
            nn.MaxPool2d((fixed_length+1-i,1)) for i in kernel_size
        ]
#         self.linear = nn.Linear(len(kernel_size)*kernel_num, 1)
        
        
        _, s01, s02, s03 = self.conv0.weight.shape
        _, s11, s12, s13 = self.conv1.weight.shape
        _, s21, s22, s23 = self.conv2.weight.shape
        self.reinforce0 = REINFORCEAgent(s01 * s02 * s03, 2)
        self.reinforce1 = REINFORCEAgent(s11 * s12 * s13, 2)
        self.reinforce2 = REINFORCEAgent(s21 * s22 * s23, 2)
        
        self.stateNactions0 = []
        self.stateNactions1 = []
        self.stateNactions2 = []
        
        self.zeros0 = torch.FloatTensor(s01, s02, s03)
        self.zeros1 = torch.FloatTensor(s11, s12, s13)
        self.zeros2 = torch.FloatTensor(s21, s22, s23)
        self.dropout = nn.Dropout(0.5)

#         self.cnet = CNet()
        
    def set_embedding_weights(self, embedding_matrix):
        self.embedding.weight = nn.Parameter(embedding_matrix)
    
    def set_pretrained(self, pretrained):
        self.conv0.load_state_dict(pretrained[0])
        self.conv1.load_state_dict(pretrained[1])
        self.conv2.load_state_dict(pretrained[2])
        self.maxpools[0].load_state_dict(pretrained[3])
        self.maxpools[1].load_state_dict(pretrained[4])
        self.maxpools[2].load_state_dict(pretrained[5])

    def train_agent(self, correct):
        for i in range(len(self.stateNactions0)):
            (state0, (action0, log_prob0)) = self.stateNactions0[i]
            (state1, (action1, log_prob1)) = self.stateNactions1[i]
            (state2, (action2, log_prob2)) = self.stateNactions2[i]
            
            reward = 1 if correct else -1
            
            self.reinforce0.append_sample(state0, action0, reward, log_prob0)
            self.reinforce1.append_sample(state1, action1, reward, log_prob1)
            self.reinforce2.append_sample(state2, action2, reward, log_prob2)
            
        self.reinforce0.train_model()
        self.reinforce1.train_model()
        self.reinforce2.train_model()
        
    def forward(self, inp):
        x = self.embedding(inp).unsqueeze(1)
            
        k0 = self.conv0.weight.clone()
        k1 = self.conv1.weight.clone()
        k2 = self.conv2.weight.clone()

        self.stateNactions0 = [(state, self.reinforce0.get_action(state)) for state in k0]
        self.stateNactions1 = [(state, self.reinforce1.get_action(state)) for state in k1]
        self.stateNactions2 = [(state, self.reinforce2.get_action(state)) for state in k2]

        c0 = self.maxpools[0](torch.tanh(self.conv0(x))).squeeze(3).squeeze(2)
        c1 = self.maxpools[1](torch.tanh(self.conv1(x))).squeeze(3).squeeze(2)
        c2 = self.maxpools[2](torch.tanh(self.conv2(x))).squeeze(3).squeeze(2)

        cc0 = c0.clone()
        cc1 = c1.clone()
        cc2 = c2.clone()
        
        for i in range(c0[0].shape[0]):
            if self.stateNactions0[i][1][0] == 0: cc0[0][i] -= cc0[0][i]
            if self.stateNactions1[i][1][0] == 0: cc1[0][i] -= cc1[0][i]
            if self.stateNactions2[i][1][0] == 0: cc2[0][i] -= cc2[0][i]
                

        x = torch.cat([c0, c1, c2], dim=1)
        y = torch.cat([cc0.detach(), cc1.detach(), cc2.detach()], dim=1)
        x = self.dropout(x)
        return x, y

In [8]:
model = RL_CNN(
        vocab_size=vocab_size,
        embedding_dim=embedding_dim,
        fixed_length=300
).cuda()
for name, w in model.named_parameters():
    if 'weight' in name:
        nn.init.xavier_normal(w)
model.set_embedding_weights(embedding_matrix.cuda())
model.set_pretrained([torch.load(fname) for fname in ['conv0.pt', 'conv1.pt', 'conv2.pt', 'maxpool0.pt', 'maxpool1.pt', 'maxpool2.pt']])

In [9]:
cnet = CNet().cuda()

In [10]:
learning_rate = 0.05
num_epochs = 20

In [11]:
len(train_loader)

1500

In [12]:
histories = []
def train_model(num_epochs, val_index=0):
    hist = np.zeros(num_epochs*(len(train_loader)))
    j = 0
    val_indices = list(range(val_index, val_index + 50))
    for e in range(num_epochs):
        for i, (x, y) in enumerate(train_loader):
            if not i in val_indices:
                model.train()
                # Clear stored gradient
                model.zero_grad()
                x, y = x.cuda(), y.cuda()

                # Forward pass
                conv_out, policy_out = model(x.long())
                y_pred = cnet(policy_out)
                loss = loss_fn(y_pred.view(-1, 1), y.view(-1, 1))

                hist[j] = loss.item()
                
                for c_idx, correct in enumerate((y_pred.view(-1, 1).round() + y.view(-1, 1) - 1).abs() == 1):
                    model.train_agent(correct.item())
#                     conv_target[c_idx]
                
                # Zero out gradient, else they will accumulate between epochs
                optimiser.zero_grad()

                # Backward pass
                loss.backward()
                
                val_index += 50
                if val_index == len(train_loader):
                    val_index = 0
                # Update parameters
                optimiser.step()
                
            j += 1
        val_set = [(x, y) for x, y in train_loader][val_index:val_index+50]
        correct_num = 0
        with torch.no_grad():
            model.eval()
            no_kernels = []
            for x, y in val_set:
                x, y = x.cuda(), y.cuda()
#                 y_pred = model(x.long())
                conv_out, policy_out = model(x.long())
                act_ser1 = pd.Series([a for (_, (a, _)) in model.stateNactions0]).value_counts()
                act_ser2 = pd.Series([a for (_, (a, _)) in model.stateNactions1]).value_counts()
                act_ser3 = pd.Series([a for (_, (a, _)) in model.stateNactions2]).value_counts()
                no_kernels.append(sum([len(act_ser1[act_ser1 == 0]), len(act_ser2[act_ser2 == 0]), len(act_ser3[act_ser3 == 0])]))
                
                y_pred = cnet(policy_out)

                correct_num += (y_pred.view(-1, 1).round() + y.view(-1, 1) - 1).abs().item()
            print("epoch {}, val score: {}, kernels not used: {}".format(e + 1, 
                                                                         correct_num / len(val_set),
                                                                         sum(no_kernels)/len(no_kernels)
                                                                        ))


In [13]:
def test_score():
    full_num = 0
    correct_num = 0
    with torch.no_grad():
        model.eval()
        losses = []
        no_kernels = []
        for i, (x, y) in enumerate(test_loader):
            x, y = x.cuda(), y.cuda()
            # Forward pass
            conv_out, policy_out = model(x.long())
            
            act_ser1 = pd.Series([a for (_, (a, _)) in model.stateNactions0]).value_counts()
            act_ser2 = pd.Series([a for (_, (a, _)) in model.stateNactions1]).value_counts()
            act_ser3 = pd.Series([a for (_, (a, _)) in model.stateNactions2]).value_counts()
            no_kernels.append(sum([len(act_ser1[act_ser1 == 0]), len(act_ser2[act_ser2 == 0]), len(act_ser3[act_ser3 == 0])]))

            y_pred = cnet(policy_out)

            correct_num += (y_pred.view(-1, 1).round() + y.view(-1, 1) - 1).abs().item()
    print("score: {}, kernels not used: {}".format(correct_num / len(test_loader), sum(no_kernels)/len(no_kernels)))

In [14]:
loss_fn = torch.nn.BCELoss()
optimiser = torch.optim.Adadelta(cnet.parameters(), lr=learning_rate, weight_decay=.03)

train_model(60)
test_score()
# train_model(10)
# test_score()
# train_model(10, 500)
# test_score()
# train_model(10, 1000)
# test_score()
# train_model(10)
# test_score()
# train_model(10, 500)
# test_score()
# train_model(10, 1000)
# test_score()

epoch 1, val score: 0.9
epoch 2, val score: 0.96
epoch 3, val score: 0.94
epoch 4, val score: 0.88
epoch 5, val score: 0.9
epoch 6, val score: 0.88
epoch 7, val score: 0.92
epoch 8, val score: 0.86
epoch 9, val score: 0.9
epoch 10, val score: 0.92
0.804
epoch 1, val score: 0.98
epoch 2, val score: 0.92
epoch 3, val score: 0.9
epoch 4, val score: 0.94
epoch 5, val score: 0.86
epoch 6, val score: 0.96
epoch 7, val score: 0.94
epoch 8, val score: 0.92
epoch 9, val score: 0.92
epoch 10, val score: 0.86
0.8
epoch 1, val score: 0.88
epoch 2, val score: 0.92
epoch 3, val score: 0.96
epoch 4, val score: 0.92
epoch 5, val score: 0.92
epoch 6, val score: 0.96
epoch 7, val score: 0.94
epoch 8, val score: 0.96
epoch 9, val score: 0.82
epoch 10, val score: 0.9
0.804
epoch 1, val score: 0.9
epoch 2, val score: 0.94
epoch 3, val score: 0.92
epoch 4, val score: 0.96
epoch 5, val score: 0.96
epoch 6, val score: 0.9
epoch 7, val score: 0.92
epoch 8, val score: 0.94
epoch 9, val score: 0.94
epoch 10, val

In [15]:
train_model(10)
test_score()
train_model(10, 500)
test_score()
train_model(10, 1000)
test_score()
train_model(10)
test_score()
train_model(10, 500)
test_score()
train_model(10, 1000)
test_score()

epoch 1, val score: 1.0
epoch 2, val score: 0.9
epoch 3, val score: 0.88
epoch 4, val score: 0.9
epoch 5, val score: 0.96
epoch 6, val score: 0.96
epoch 7, val score: 0.94
epoch 8, val score: 0.92
epoch 9, val score: 0.92
epoch 10, val score: 0.92
0.792
epoch 1, val score: 0.92
epoch 2, val score: 0.94
epoch 3, val score: 0.92
epoch 4, val score: 0.86
epoch 5, val score: 0.88
epoch 6, val score: 0.9
epoch 7, val score: 0.94
epoch 8, val score: 0.92
epoch 9, val score: 0.9
epoch 10, val score: 0.9
0.802
epoch 1, val score: 0.88
epoch 2, val score: 0.92
epoch 3, val score: 0.92
epoch 4, val score: 0.82
epoch 5, val score: 0.96
epoch 6, val score: 0.92
epoch 7, val score: 0.94
epoch 8, val score: 0.88
epoch 9, val score: 0.9
epoch 10, val score: 0.9
0.8
epoch 1, val score: 0.9
epoch 2, val score: 0.94
epoch 3, val score: 0.92
epoch 4, val score: 0.9
epoch 5, val score: 0.96
epoch 6, val score: 0.9
epoch 7, val score: 0.96
epoch 8, val score: 0.86
epoch 9, val score: 0.88
epoch 10, val sco