In [1]:
import numpy as np
import random
from sklearn.metrics import f1_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics.pairwise import cosine_similarity as cos_sim
from datetime import datetime

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical

In [None]:
model_name = 'facebook-bart-large'
cuda_device = 1

In [3]:
class BaseNetwork(nn.Module):
    
    def __init__(self, args):
        super(BaseNetwork, self).__init__()
        self.args = args
        self.layers = []
        i = 0
        for s1, s2 in zip(args.sizes[:-1], args.sizes[1:]):
            self.layers.append(nn.Linear(s1, s2))
            self.register_parameter('weight-layer-' + str(i), self.layers[-1].weight)
            self.register_parameter('bias-layer-' + str(i), self.layers[-1].bias)
            nn.init.xavier_uniform_(self.layers[-1].weight)
            i += 1
        self.nl = nn.Tanh()
        
    def forward(self, x):
        for layer in self.layers:
            x = self.nl(layer(x))
        return x

In [4]:
class TaskSpecificNetwork(nn.Module):
    
    def __init__(self, args, task_metadata=None):
        super(TaskSpecificNetwork, self).__init__()
        self.args = args
        self.base_network = BaseNetwork(args)
        self.out_layers = []
        i = 0
        for task in args.class_tasks:
            self.out_layers.append(nn.Linear(args.sizes[-1], task))
            self.register_parameter('out-weight-layer-' + str(i), self.out_layers[-1].weight)
            self.register_parameter('out-bias-layer-' + str(i), self.out_layers[-1].bias)
            nn.init.xavier_uniform_(self.out_layers[-1].weight)
            i += 1
        self.nl = nn.Tanh()
        
    def forward(self, x, task_id):
        rep_x = self.base_network(x)
        x = self.out_layers[task_id](self.nl(rep_x))
        out = F.softmax(x, dim=1)
        return rep_x, out

In [5]:
class TaskSpecificNetworkMultilabel(nn.Module):
    
    def __init__(self, args, task_metadata=None):
        super(TaskSpecificNetworkMultilabel, self).__init__()
        self.args = args
        self.base_network = BaseNetwork(args)
        self.out_layers = []
        i = 0
        for task, task_type in zip(args.class_tasks, args.class_types):
            self.out_layers.append(nn.Linear(args.sizes[-1], task))
            self.register_parameter('out-weight-layer-' + str(i), self.out_layers[-1].weight)
            self.register_parameter('out-bias-layer-' + str(i), self.out_layers[-1].bias)
            nn.init.xavier_uniform_(self.out_layers[-1].weight)
            i += 1
        self.nl = nn.Tanh()
        
    def forward(self, x, task_id):
        rep_x = self.base_network(x)
        x = self.out_layers[task_id](self.nl(rep_x))
        if self.args.class_types[task_id] == 'classification':
            out = F.softmax(x, dim=1)
        elif self.args.class_types[task_id] == 'multilabel':
            out = torch.sigmoid(x)
        return rep_x, out

In [6]:
class ActorNetwork(nn.Module):
    
    def __init__(self, args):
        super(ActorNetwork, self).__init__()
        self.args = args
        self.actor = nn.Linear(args.state_dim, 2)
        nn.init.xavier_uniform_(self.actor.weight)
        self.nl = nn.Tanh()
        
    def forward(self, x):
        action_probs = F.softmax(self.actor(x), dim=-1)
        return action_probs

In [7]:
class ActorNetwork(nn.Module):
    
    def __init__(self, args):
        super(ActorNetwork, self).__init__()
        self.args = args
        self.actor = nn.Linear(args.state_dim, 2)
        nn.init.xavier_uniform_(self.actor.weight)
        self.nl = nn.Tanh()
        
    def forward(self, x):
        action_probs = F.softmax(self.actor(x), dim=-1)
        return action_probs

In [8]:
class CriticNetwork(nn.Module):
    
    def __init__(self, args):
        super(CriticNetwork, self).__init__()
        self.args = args
        self.task_model = TaskSpecificNetworkMultilabel(args)
        self.ff1 = nn.Linear(args.state_dim, args.hdim)
        nn.init.xavier_uniform_(self.ff1.weight)
        self.critic_layer = nn.Linear(args.hdim, 1)
        nn.init.xavier_uniform_(self.critic_layer.weight)
        self.nl = nn.Tanh()
        
    def forward(self, x):
        x_out = self.task_model.base_network(x).detach()
        c_in = self.nl(self.ff1(x_out))
        out = torch.sigmoid(self.critic_layer(c_in))
        out = torch.mean(out)
        return x_out, out
    
    def task_output(self, x, task_id):
        return self.task_model(x, task_id)

In [9]:
class PolicyNetwork(nn.Module):
    
    def __init__(self, args):
        super(PolicyNetwork, self).__init__()
        self.args = args
        self.actor = ActorNetwork(args)
        self.critic = CriticNetwork(args)
        self.task_optims = []
        self.loss_fns = []
        
        for i in range(args.num_tasks):
            params = []
            base_params = [p for p in self.critic.task_model.base_network.parameters() if p.requires_grad]
            task_params = [p for p in self.critic.task_model.out_layers[i].parameters() if p.requires_grad]
            params = base_params + task_params
            self.task_optims.append(optim.Adam(params, lr=args.lr))
            if args.class_types[i] == 'classification':
                if args.use_cuda:
                    with torch.cuda.device(cuda_device):
                        args.loss_weights[i] = args.loss_weights[i].cuda()
                self.loss_fns.append(nn.CrossEntropyLoss(weight=args.loss_weights[i]))
            elif args.class_types[i] == 'multilabel':
                self.loss_fns.append(nn.BCELoss())
        self.saved_actions = []
        self.rewards = []
        
    def forward(self, batch_x):
        if self.args.use_cuda:
            with torch.cuda.device(cuda_device):
                batch_x = batch_x.cuda()
        batch_rep, exp_reward = self.critic(batch_x)
        action_probs = self.actor(batch_rep)
        return action_probs, batch_rep, exp_reward
        
    def compute_reward(self, eval_data):
        pred_ys = []
        data_ys = []
        for eval_batch in eval_data:
            batch_x, batch_y = eval_batch
            if self.args.use_cuda:
                with torch.cuda.device(cuda_device):
                    batch_x = batch_x.cuda()
                    batch_y = batch_y.cuda()
            _, pred_out = self.critic.task_output(batch_x, -1)
            if self.args.class_types[-1] == 'classification':
                pred_y = torch.argmax(pred_out, dim=1)
            elif self.args.class_types[-1] == 'multilabel':
                pred_y = (pred_out >= 0.5).long()
            pred_ys.append(pred_y)
            data_ys.append(batch_y)
        pred_Y = torch.cat(pred_ys, dim=0)
        data_Y = torch.cat(data_ys, dim=0)
        f1_ma = float(f1_score(data_Y.cpu().data, pred_Y.cpu().data, average='macro'))
        return f1_ma
        
    def train_minibatch(self, batch_x, batch_y, task_id):
        if self.args.use_cuda:
            with torch.cuda.device(cuda_device):
                batch_x = batch_x.cuda()
                batch_y = batch_y.cuda()
        _, batch_out = self.critic.task_output(batch_x, task_id)
        batch_loss = self.loss_fns[task_id](batch_out, batch_y)
        self.critic.task_model.zero_grad()
        batch_loss.backward()
        self.task_optims[task_id].step()

In [10]:
eps = np.finfo(np.float32).eps.item()
def finish_episode(policy_model, data_sets, eval_data):
    
    for dnum, dset in enumerate(data_sets):
        train, dev, test = dset
        for train_batch in train:
            batch_x, batch_y = train_batch
            action_probs, batch_rep, exp_reward = policy_model(batch_x)
            m = Categorical(action_probs)
            action = m.sample()
            batch_mask = action.cpu() == torch.ones(action.size())
            sel_x = batch_x[batch_mask, :]
            sel_y = batch_y[batch_mask]
            policy_model.train_minibatch(sel_x, sel_y, dnum)
            reward = policy_model.compute_reward(eval_data)
            policy_model.saved_actions.append((m.log_prob(action), exp_reward))
            policy_model.rewards.append(reward)
            
        policy_losses = []
        value_losses = []

        R_mean = np.mean(policy_model.rewards)
        R_std = np.std(policy_model.rewards)
        for i, r in enumerate(policy_model.rewards):
            policy_model.rewards[i] = float((r - R_mean) / (R_std + eps))

        for (log_prob, value), R in zip(policy_model.saved_actions, policy_model.rewards):
            advantage = R - value.item()

            # calculate actor (policy) loss 
            policy_losses.append(-log_prob * advantage)

            # calculate critic (value) loss using L1 smooth loss
            R_tensor = torch.tensor([R])
            if policy_model.args.use_cuda:
                with torch.cuda.device(cuda_device):
                    R_tensor = R_tensor.cuda()
            value_losses.append(F.smooth_l1_loss(value.view(1,), R_tensor))

        # reset gradients
        optimizer.zero_grad()

        # sum up all the values of policy_losses and value_losses
        loss = torch.cat(policy_losses, dim=0).sum() + torch.stack(value_losses).sum()

        # perform backprop
        loss.backward()
        optimizer.step()

        # reset rewards and action buffer
        del policy_model.rewards[:]
        del policy_model.saved_actions[:]

In [11]:
def finish_mtl_epoch(mtl_model, data_sets, eval_data):
    for dnum, dset in enumerate(data_sets):
        train, dev, test = dset
        for train_batch in train:
            batch_x, batch_y = train_batch
            mtl_model.train_minibatch(batch_x, batch_y, dnum)

In [12]:
class Args():
    
    def __init__(self):
        self.sizes = [300, 500, 100]
        self.hdim = 50
        self.state_dim = 100
        self.class_tasks = [2, 2, 2, 2, 2]
        self.num_tasks = 5
        self.lr = 1e-2
        self.use_cuda = True
        self.class_types = ['classification'] * 5
        
args1 = Args()

## Real Data Experiments

### Loading Data

In [13]:
import pickle
dpath = '/scratch1/rpujari/gcr_workspace/data/'
with open(dpath + 'batched_dsets_multilabel_' + model_name + 'bsz64.pkl', 'rb') as infile:
    batched_dsets = pickle.load(infile)
print(batched_dsets.keys())

dict_keys(['jigsaw-dataset', 'hate-speech-dataset', 'hate-speech-and-offensive-language', 'ami-ibereval-dataset', 'stereotype'])


In [14]:
dsets = []
for label in batched_dsets:
    if label != 'stereotype':
        dsets.append(batched_dsets[label])
dsets.append(batched_dsets['stereotype'])
trimmed_dsets = []
nbatches = 300 #Change to 2000 to run experiments on full data 
for dset in dsets:
    trimmed_dsets.append((dset[0][:nbatches], dset[1][:nbatches], dset[2][:nbatches]))
edata = batched_dsets['ami-ibereval-dataset'][1]

In [15]:
for dname in batched_dsets:
    print(dname, len(batched_dsets[dname][0]))

jigsaw-dataset 1995
hate-speech-dataset 120
hate-speech-and-offensive-language 271
ami-ibereval-dataset 36
stereotype 263


In [16]:
class_tasks = [6, 2, 3, 2, 2]
num_tasks = 5
lr = 3e-5
class_types = ['multilabel', 'classification', 'classification', 'classification', 'classification']

In [17]:
#computing loss weights based on class distributions in the training data
o = 0
l = 0
loss_weights = [None]
for dnum, dset in enumerate(trimmed_dsets):
    if dnum > 0:
        tr, de, te = dset
        ldict = {}
        t = 0
        loss_weight = torch.ones(class_tasks[dnum]).float()
        data_ys = []
        for batch in de:
    #         print(batch[1].size())
            for i in range(batch[1].size(0)):
                if batch[1][i].data.item() not in ldict:
                    ldict[batch[1][i].data.item()] = 0
                ldict[batch[1][i].data.item()] += 1
                t += 1
            data_ys.append(batch[1])
        data_y = torch.cat(data_ys, dim=0)
        pred_y = torch.ones(data_y.size())
        print(t, ldict)
        for i in range(class_tasks[dnum]):
            loss_weight[i] = t / ldict[i]
        loss_weights.append(loss_weight)
        f1_ma = f1_score(data_y.cpu().data, pred_y.cpu().data, average='macro') 
        print(f1_ma)

1670 {0: 1471, 1: 199}
0.10647405029427502
3746 {1: 2905, 2: 624, 0: 217}
0.2911842830652032
502 {1: 264, 0: 238}
0.34464751958224543
3478 {0: 2219, 1: 1259}
0.26578002955457036


### Baseline Results

In [18]:
class FeedForward(nn.Module):
    
    def __init__(self, sizes, use_cuda, lr, momentum, weight_decay, loss_weight):
        super(FeedForward, self).__init__()
        self.layers = []
        i = 0
        for s1, s2 in zip(sizes[:-1], sizes[1:]):
            self.layers.append(nn.Linear(s1, s2))
            self.register_parameter('weight-layer-' + str(i), self.layers[-1].weight)
            self.register_parameter('bias-layer-' + str(i), self.layers[-1].bias)
            nn.init.xavier_uniform_(self.layers[-1].weight)
            i += 1
        self.loss_fn = nn.CrossEntropyLoss(weight=loss_weight)
        self.nl = nn.Tanh()
        params = [p for p in self.parameters() if p.requires_grad]
        self.optimizer = optim.Adam(params, lr=lr)
        self.use_cuda = use_cuda
        
    def forward(self, x):
        for layer in self.layers:
            x = self.nl(layer(x))
        out = F.softmax(x, dim=1)
        return out
        
    def evaluate(self, data):
        self.eval()
        batch_outs = []
        batch_preds = []
        batch_ys = []
        for batch_x, batch_y in data:
            if self.use_cuda:
                with torch.cuda.device(cuda_device):
                    batch_x = batch_x.cuda()
                    batch_y = batch_y.cuda()
            batch_out = self.forward(batch_x)
            batch_pred = torch.argmax(batch_out, dim=1)
            batch_outs.append(batch_out)
            batch_preds.append(batch_pred)
            batch_ys.append(batch_y)
        pred_y = torch.cat(batch_preds, dim=0)
        data_y = torch.cat(batch_ys, dim=0)
        pred_out = torch.cat(batch_outs, dim=0)
        acc = sum((pred_y == data_y).float()) / data_y.size(0)
        f1_mi = f1_score(data_y.cpu().data, pred_y.cpu().data, average='micro')
        f1_ma = f1_score(data_y.cpu().data, pred_y.cpu().data, average='macro')
        con_mat = confusion_matrix(data_y.cpu().data, pred_y.cpu().data)
        val_loss = self.loss_fn(pred_out, data_y)
        return acc.cpu(), val_loss.cpu(), (f1_mi, f1_ma, con_mat)
        
    def train_model(self, train, val, num_epochs=25, save_path='./model.pkl'):
        self.train()
        for i in range(num_epochs):
            max_val = -1
            for batch_x, batch_y in train:
                if self.use_cuda:
                    with torch.cuda.device(cuda_device):
                        batch_x = batch_x.cuda()
                        batch_y = batch_y.cuda()
                ids = list(range(batch_x.size(0)))
                random.shuffle(ids)
                batch_x = batch_x[ids, :]
                batch_y = batch_y[ids]
                batch_out = self.forward(batch_x)
                loss = self.loss_fn(batch_out, batch_y)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
            val_acc, val_loss, (f1_mi, f1_ma, cm) = self.evaluate(val)
            if f1_ma > max_val:
                max_val = f1_ma
                torch.save(self.state_dict(), save_path)

In [19]:
class Args2():
    
    def __init__(self):
        self.edim = 1024
        self.tdim = 300
        self.hdim = 50
        self.ddim = 20
        self.num_deps = 78
        self.padding_idx = 0
        self.sizes = [1024, 300, 100, 2]
        self.use_cuda = True
        self.lr = 3e-5
        self.momentum = 0.1
        self.weight_decay = 0
        self.device = cuda_device
        self.loss_weight = torch.from_numpy(np.array([1., 1.])).type(torch.FloatTensor)
        
args2 = Args2()

In [20]:
t1 = datetime.now()
num_epochs = 100
for i in range(1, len(trimmed_dsets)):
    dset = trimmed_dsets[i]
    tr, de, te = dset
    torch.manual_seed(13)
    args2.sizes[-1] = class_tasks[i]
    args2.loss_weight = loss_weights[i]
    print(args2.sizes)
    model = FeedForward(args2.sizes, args2.use_cuda, args2.lr, args2.momentum, args2.weight_decay, args2.loss_weight)
    with torch.cuda.device(cuda_device):
        model.cuda(cuda_device)

    de = dset[1]
    te = dset[2]
    model.train_model(tr, de, num_epochs=num_epochs, save_path=dpath + 'trained_models/base_model' + str(i) + '_' + model_name + '.pkl')
    model.load_state_dict(torch.load(dpath + 'trained_models/base_model' + str(i) + '_' + model_name + '.pkl'))
    res = model.evaluate(de)
#     con_mat = res[-1][-1]
#     acc = res[0]
#     prec_1 = con_mat[1, 1] / (con_mat[1, 1] + con_mat[0, 1])
#     rec_1 = con_mat[1, 1] / (con_mat[1, 1] + con_mat[1, 0])
#     f1_1 = 2 * prec_1 * rec_1 / (prec_1 + rec_1)
    t2 = datetime.now()
    print(res[2][1], t2 - t1, '\n')#, (round(acc.data.item(), 4), round(prec_1, 4), round(rec_1, 4), round(f1_1, 4)))    

[1024, 300, 100, 2]
0.5717471772303054 0:00:22.237974 

[1024, 300, 100, 3]
0.4540735866182117 0:01:09.511269 

[1024, 300, 100, 2]
0.6061129151024252 0:01:16.227762 

[1024, 300, 100, 2]
0.6262672507943621 0:01:56.381867 



### Multi-task Learning Results

In [21]:
args1.sizes = [1024, 300, 100]
args1.class_tasks = [6, 2, 3, 2, 2]
args1.num_tasks = 5
args1.lr = 3e-5
args1.class_types = ['multilabel', 'classification', 'classification', 'classification', 'classification']
args1.loss_weights = loss_weights 

In [22]:
t1 = datetime.now()
num_ep = 200

#For each dataset in the set of datasets
for i in range(1, len(trimmed_dsets)):
    
    #evaluation dataset - dev data of selected dataset i
    edata = trimmed_dsets[i][1]
    #test dataset - test data of selected dataset i
    tdata = trimmed_dsets[i][2]
    
    #create arguments by excluding the dataset in focus
    sample_dsets = trimmed_dsets[:i] + trimmed_dsets[i+1:]# + trimmed_dsets[i]
    args1.class_tasks = class_tasks[:i] + class_tasks[i+1:]# + class_tasks[i]
    args1.class_types = class_types[:i] + class_types[i+1:]# + class_types[i]
    args1.loss_weights = loss_weights[:i] + loss_weights[i+1:]# + loss_weights[i]
    args1.num_tasks = len(sample_dsets)
    args1.prefix = '_p1_dset' + str(i) + '_' + model_name 
    
    #initialize a multi-task model
    torch.manual_seed(13)
    random.seed(13)
    with torch.cuda.device(cuda_device):
        model = PolicyNetwork(args1)
        model.cuda(cuda_device)
        optimizer = optim.Adadelta(model.parameters(), args1.lr)
    
    #Train on all datasets - dataset i, with eval data as the dev data of the last dataset
    max_reward = -1
    max_epoch = -1
    for epoch in range(num_ep):
        finish_mtl_epoch(model, sample_dsets, edata)
        t2 = datetime.now()
        Re = model.compute_reward(sample_dsets[-1][1])
        if Re > max_reward:
            torch.save(model.state_dict(), dpath + 'trained_models/mtl_model' + args1.prefix + '.pkl')
            max_reward = Re
            max_epoch = epoch
#         print('Epoch', epoch, Re, t2-t1)
    
    #create arguments with only dataset i
    sample_dsets = [trimmed_dsets[i]]
    args1.class_tasks = [class_tasks[i]]
    args1.class_types = [class_types[i]]
    args1.loss_weights = [loss_weights[i]]
    args1.num_tasks = len(sample_dsets)
    args1.prefix = '_p2_dset' + str(i) + '_' + model_name 
    
    #initialize a new model for just task i
    torch.manual_seed(13)
    random.seed(13)
    with torch.cuda.device(cuda_device):
        model = PolicyNetwork(args1)
        model.cuda(cuda_device)
        optimizer = optim.Adadelta(model.parameters(), args1.lr)
    
    #extract the base network parameters from the trained multi-task model
    state_dict = torch.load(dpath + 'trained_models/mtl_model_p1_dset' + str(i) + '_' + model_name + '.pkl')
    del_keys = []
    for key in state_dict:
        if key.startswith('critic.task_model.out-'):
            del_keys.append(key)
    for key in del_keys:
        del state_dict[key]
    #initialize with best model from multi-task learning phase and evaluate the model on dataset i
    model.load_state_dict(state_dict, strict=False)
    print(model.compute_reward(edata))
    
    #train the model with training data - train data i and eval data as dev data i
    max_reward = -1
    max_epoch = -1
    for epoch in range(num_ep):
        finish_mtl_epoch(model, sample_dsets, edata)
        t2 = datetime.now()
        Re = model.compute_reward(edata)
        if Re > max_reward:
            torch.save(model.state_dict(), dpath + 'trained_models/mtl_model' + args1.prefix + '.pkl')
            max_reward = Re
            max_epoch = epoch
#         print('Epoch', epoch, Re, t2-t1)
    
    #load best model parameters and evluate the traine model
    model.load_state_dict(torch.load(dpath + 'trained_models/mtl_model' + args1.prefix + '.pkl'))
    o = 0
    l = 0
    for batch in tdata:
        o += torch.sum(batch[1]).data.item()
        l += len(batch[1])
    print(o, l, o/l)
    t2 = datetime.now()
    print('Test Performance - ', i, ': ', model.compute_reward(edata), t2-t1, '\n\n')

0.46670950676893264
211 1648 0.12803398058252427
Test Performance -  1 :  0.5963610175733979 0:04:02.817876 


0.21875870009219686
4115 3712 1.1085668103448276
Test Performance -  2 :  0.47279802631471685 0:08:12.488411 


0.445126151381658
225 498 0.45180722891566266
Test Performance -  3 :  0.5824495737159492 0:12:27.231444 


0.4762084306365222
1276 3610 0.35346260387811634
Test Performance -  4 :  0.6857974107269882 0:16:44.239686 




### RL Learning Results

In [23]:
t1 = datetime.now()
num_ep = 100

#For each dataset in the set of datasets
for i in range(1, len(trimmed_dsets)):
    
    #evaluation dataset - dev data of selected dataset i
    edata = trimmed_dsets[i][1]
    #test dataset - test data of selected dataset i
    tdata = trimmed_dsets[i][2]
    
    #create arguments by placing the target dataset as the last task
    sample_dsets = trimmed_dsets[:i] + trimmed_dsets[i+1:] + [trimmed_dsets[i]]
    args1.class_tasks = class_tasks[:i] + class_tasks[i+1:] + [class_tasks[i]]
    args1.class_types = class_types[:i] + class_types[i+1:] + [class_types[i]]
    args1.loss_weights = loss_weights[:i] + loss_weights[i+1:] + [loss_weights[i]]
    args1.num_tasks = len(sample_dsets)
    args1.prefix = '_p3_dset' + str(i) + '_' + model_name 
    
    #initialize the policy network
    torch.manual_seed(13)
    random.seed(13)
    with torch.cuda.device(cuda_device):
        model = PolicyNetwork(args1)
        model.cuda(cuda_device)
        optimizer = optim.Adadelta(model.parameters(), args1.lr)
    
    #extract the task model parameters from the trained multi-task models
    state_dict2 = torch.load(dpath + 'trained_models/mtl_model_p2_dset' + str(i) + '_' + model_name + '.pkl')
    state_dict1 = torch.load(dpath + 'trained_models/mtl_model_p1_dset' + str(i) + '_' + model_name + '.pkl')
    state_dict = model.state_dict()
    
    done = -1
    #initialize all the sister task parameters from stage 1 mtl model
    for key in state_dict1:
        if key.startswith('critic.task_model.out-'):
            task_id = int(key.split('-')[-1])
            if task_id > done:
                done = task_id
            state_dict[key] = state_dict1[key]
    #initialize target task parameters from stage 2 mtl model
    for key in state_dict2:
        if key.startswith('critic.task_model.out-'):
            ksplit = key.split('-')
            #change the target task id from 0 (in stage 2 mtl model) to (num_tasks - 1)
            ksplit[-1] = str(done + 1)
            state_dict['-'.join(ksplit)] = state_dict2[key]
        elif key.startswith('critic.task_model'):
            state_dict[key] = state_dict2[key]
    model.load_state_dict(state_dict)
    
    #Train RL model on all datasets, with eval data - dev data of target task ( dataset i)
    max_reward = -1
    max_epoch = -1
    for epoch in range(num_ep):
        finish_episode(model, sample_dsets, edata)
        t2 = datetime.now()
        Re = model.compute_reward(edata)
        if Re > max_reward:
            torch.save(model.state_dict(), dpath + 'trained_models/rl_model' + args1.prefix + '.pkl')
            max_reward = Re
            max_epoch = epoch
#         print('Epoch', epoch, Re, t2-t1)
    
    #create arguments for dataset i
    sample_dsets = [trimmed_dsets[i]]
    args1.class_tasks = [class_tasks[i]]
    args1.class_types = [class_types[i]]
    args1.loss_weights = [loss_weights[i]]
    args1.num_tasks = len(sample_dsets)
    args1.prefix = '_p4_dset' + str(i) + '_' + model_name 
    
    #initialize a new model for just task i
    torch.manual_seed(13)
    random.seed(13)
    with torch.cuda.device(cuda_device):
        model = PolicyNetwork(args1)
        model.cuda(cuda_device)
        optimizer = optim.Adadelta(model.parameters(), args1.lr)
    
    #extract the parameters of target task from stage 3 RL model
    state_dict = torch.load(dpath + 'trained_models/rl_model_p3_dset' + str(i) + '_' + model_name + '.pkl')
    del_keys = []
    #remove paramters of other tasks from state_dict
    for key in state_dict:
        if key.startswith('critic.task_model.out-'):
            del_keys.append(key)
    for key in del_keys:
        if not key.endswith(str(done + 1)):
            del state_dict[key]
    for key in del_keys:
        if key.endswith(str(done + 1)):
            #rename target task paramters from num_tasks-1 to 0 
            ksplit = key.split('-')
            ksplit[-1] = '0'
            state_dict['-'.join(ksplit)] = state_dict[key]
            del state_dict[key]
            
    #initialize with best model from stage 3 and evaluate the model on dev data of dataset i
    model.load_state_dict(state_dict)
    print(model.compute_reward(edata))
    
    #train the model on target task
    max_reward = -1
    max_epoch = -1
    for epoch in range(num_ep):
        finish_episode(model, sample_dsets, edata)
        t2 = datetime.now()
        Re = model.compute_reward(edata)
        if Re > max_reward:
            torch.save(model.state_dict(), dpath + 'trained_models/rl_model' + args1.prefix + '.pkl')
            max_reward = Re
            max_epoch = epoch
#         print('Epoch', epoch, Re, t2-t1)
    
    #load best model parameters from stage 4 and evluate the trained model
    model.load_state_dict(torch.load(dpath + 'trained_models/rl_model' + args1.prefix + '.pkl'))
    o = 0
    l = 0
    for batch in tdata:
        o += torch.sum(batch[1]).data.item()
        l += len(batch[1])
    print(o, l, o/l)
    t2 = datetime.now()
    print('Test Performance - ', i, ': ', model.compute_reward(edata), t2-t1, '\n\n')

0.5806905031760894
211 1648 0.12803398058252427
Test Performance -  1 :  0.5760629112931797 0:26:30.356505 


0.46456490251260485
4115 3712 1.1085668103448276
Test Performance -  2 :  0.473264962789864 1:20:58.228438 


0.5709812942285267
225 498 0.45180722891566266
Test Performance -  3 :  0.5856351608388262 1:36:40.073597 


0.680348173113854
1276 3610 0.35346260387811634
Test Performance -  4 :  0.6850467901315358 2:48:28.480076 




### Test Data Performance

In [24]:
print(batched_dsets.keys())

dict_keys(['jigsaw-dataset', 'hate-speech-dataset', 'hate-speech-and-offensive-language', 'ami-ibereval-dataset', 'stereotype'])


In [25]:
for i in range(1, len(trimmed_dsets)):
    dset = trimmed_dsets[i]
    args1.class_tasks = [class_tasks[i]]
    args1.class_types = [class_types[i]]
    args1.loss_weights = [loss_weights[i]]
    args1.num_tasks = 1
    
    print('Task ID:', i, list(batched_dsets.keys())[i])
    with torch.cuda.device(cuda_device) and torch.no_grad():
        tdata = dset[2]
        
        args2.sizes[-1] = class_tasks[i]
        args2.loss_weight = loss_weights[i]
        model = FeedForward(args2.sizes, args2.use_cuda, args2.lr, args2.momentum, args2.weight_decay, args2.loss_weight)
        model.load_state_dict(torch.load(dpath + 'trained_models/base_model' + str(i) + '_' + model_name + '.pkl'))
        model.cuda(cuda_device)
        print('Baseline:', model.evaluate(tdata)[2][1] * 100)
        
        model = PolicyNetwork(args1)
        model.cuda(cuda_device)

        model.load_state_dict(torch.load(dpath + 'trained_models/mtl_model_p2_dset' + str(i) + '_' + model_name + '.pkl'))
        print('Multi-task:', model.compute_reward(tdata) * 100)

        model.load_state_dict(torch.load(dpath + 'trained_models/rl_model_p4_dset' + str(i) + '_' + model_name + '.pkl'))
        print('RL:', model.compute_reward(tdata) * 100)
        print('\n')

Task ID: 1 hate-speech-dataset
Baseline: 59.135914215660776
Multi-task: 60.132416603004835
RL: 60.33338955416877


Task ID: 2 hate-speech-and-offensive-language
Baseline: 48.330148764224106
Multi-task: 47.100267438651144
RL: 47.20860623456615


Task ID: 3 ami-ibereval-dataset
Baseline: 63.15867345860371
Multi-task: 63.93407142262549
RL: 62.534722222222214


Task ID: 4 stereotype
Baseline: 63.707305154804715
Multi-task: 67.87178307690164
RL: 68.20354359523967




### Mturk Annotated Data Zero-Shot Results

In [28]:
with open(dpath + 'mturk_batched_dsets_multilabel_' + model_name + '_bsz64.pkl', 'rb') as infile:
    mturk_batched_dsets = pickle.load(infile)
print(mturk_batched_dsets.keys())

dict_keys(['stereotype-gold-binary', 'stereotype-gold-multilabel'])


In [29]:
dset = mturk_batched_dsets['stereotype-gold-binary']
args1.class_tasks = [2]
args1.class_types = ['classification']
#placeholder, not used in inference
args1.loss_weights = [loss_weights[-1]]
args1.num_tasks = 1
i = 4

with torch.cuda.device(cuda_device) and torch.no_grad():
    tdata = dset[2]

    args2.sizes[-1] = 2
    #placeholder, not used in inference
    args2.loss_weight = loss_weights[-1]
    model = FeedForward(args2.sizes, args2.use_cuda, args2.lr, args2.momentum, args2.weight_decay, args2.loss_weight)
    model.load_state_dict(torch.load(dpath + 'trained_models/base_model' + str(i) + '_' + model_name + '.pkl'))
    model.cuda(cuda_device)
    print('Baseline:', model.evaluate(tdata)[2][1] * 100)

    model = PolicyNetwork(args1)
    model.cuda(cuda_device)

    model.load_state_dict(torch.load(dpath + 'trained_models/mtl_model_p2_dset' + str(i) + '_' + model_name + '.pkl'))
    print('Multi-task:', model.compute_reward(tdata) * 100)

    model.load_state_dict(torch.load(dpath + 'trained_models/rl_model_p4_dset' + str(i) + '_' + model_name + '.pkl'))
    print('RL:', model.compute_reward(tdata) * 100)
    print('\n')

Baseline: 53.80277617119722
Multi-task: 57.00245700245701
RL: 56.36545636545638


