In [1]:
import json
import torch
import pickle
import numpy as np
import torch.nn as nn
from scipy import stats
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from nn_utils import Attention, EmotionRegression, FeatureExtaction

In [2]:
batch_size = 64
lr_attn = 1e-5
lr_feature = 8e-6
lr_regressor = 4e-6
# lr_discriminator = 4e-6
epochs = 10
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(device)

cuda:0


In [3]:
dict = {}
with open("dataset.pkl", "rb") as fin :
    dict = pickle.load(fin)

vocab = dict['vocab'].to(device)
vocab.requires_grad = False
train_dataloader = DataLoader(
    dict['train_dataset'],
    shuffle=True,
    batch_size=batch_size
    )
val_dataloader = DataLoader(
    dict['dev_dataset'],
    shuffle=False,
    batch_size=batch_size
    )

In [4]:
class AAN(nn.Module) :
    def __init__(self, embed_size=300, hidden_size=150) :
        #['V', 'A', 'D', 'S']
        super(AAN, self).__init__()
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.feature_size = self.hidden_size*4

        self.attention_1 = Attention(self.embed_size)
        self.attention_2 = Attention(self.embed_size)
        self.attention_s = Attention(self.embed_size)

        self.features = FeatureExtaction(self.embed_size, self.hidden_size)
        
        self.regression_1 = EmotionRegression(self.feature_size * 2)
        self.regression_2 = EmotionRegression(self.feature_size * 2)

#         self.discriminator = Discriminator(self.feature_size)
    
    def forward(self, vocab, sentences, source_lengths) :
        sentences = vocab(sentences)
        sentences = sentences.detach()

        sentences_1 = self.attention_1(sentences, source_lengths)
        sentences_2 = self.attention_2(sentences, source_lengths)
        sentences_s = self.attention_s(sentences, source_lengths)

        features_1 = self.features(sentences_1, source_lengths)
        features_2 = self.features(sentences_2, source_lengths)
        features_s = self.features(sentences_s, source_lengths)

        value_1 = self.regression_1(torch.cat((features_1, features_s), dim=1))
        value_2 = self.regression_2(torch.cat((features_2, features_s), dim=1))

#         p1, p2 = self.discriminator(features_1), self.discriminator(features_2)
        
        return value_1, value_2#, p1, p2


In [5]:
class Train() :
    def __init__(self, type=[0,1]) :
        super(Train, self).__init__()
        self.type = type
        self.model = AAN().to(device)
        # self.init_weights()
        self.mse = nn.MSELoss()
        self.attention_optim = optim.Adam(
            list(self.model.attention_1.parameters())+
            list(self.model.attention_2.parameters())+
            list(self.model.attention_s.parameters()), lr=lr_attn)
        self.attention_optim_adversarial = optim.Adam(
            list(self.model.attention_1.parameters())+
            list(self.model.attention_2.parameters())+
            list(self.model.attention_s.parameters()), lr=lr_attn)
        self.feature_optim = optim.Adam(self.model.features.parameters(), lr=lr_feature)
        self.regressor_optim = optim.RMSprop(
            list(self.model.regression_1.parameters())+
            list(self.model.regression_1.parameters()), lr=lr_regressor)
#         self.discriminator_optim = optim.RMSprop(self.model.discriminator.parameters(), lr=lr_discriminator)
        self.training_stats = []
    
    def init_weights(self) :
        for param in self.model.parameters() :
            temp = np.sqrt(6.0/(sum([i for i in param.shape])+1e-8))
            param.data.uniform_(-temp, temp)
    
    def run_model(self, batch) :
        sentences = batch[0].to(device)
        source_lengths = batch[1].to(device)
        target = batch[2].to(device)
        value_1, value_2 = self.model(vocab, sentences, source_lengths)
        output = torch.cat((value_1, value_2), dim=1)
        target = target[:,self.type].float().detach()
        return output, target# p1, p2, target
    
    def get_r(self, output, target) :
        return [0, 0]
        temp = [(stats.pearsonr(output[:,i].cpu().detach(), target[:,i].cpu().detach())) for i in range(2)]
        return temp
    
    def train(self, epochs=epochs) :
        mse = nn.MSELoss()
        for epoch_i in range(epochs) :
            print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
            self.model.train()
            train_r = []
            train_loss = []
            for i, batch in enumerate(train_dataloader) :

                self.attention_optim.zero_grad()
                self.feature_optim.zero_grad()
                self.regressor_optim.zero_grad()

                output, target = self.run_model(batch)
                reg_loss = mse(output, target)
                train_loss.append(reg_loss.item())
                reg_loss.backward()
#                 torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

                self.attention_optim.step()
                self.feature_optim.step()
                self.regressor_optim.step()

                # self.discriminator_optim.zero_grad()

                # output, p1, p2, target = self.run_model(batch)
                # wloss = (p2-p1).mean()
                # wloss.backward()
                # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                
                # self.discriminator_optim.step()

                # self.attention_optim_adversarial.zero_grad()
                
                # output, p1, p2, target = self.run_model(batch)
                # adversarial_loss = (p1-p2).mean()
                # adversarial_loss.backward()
                # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                
                # self.attention_optim_adversarial.step()

                train_r.append(self.get_r(output, target))
                if i%10 == 0 :
                    print("Batch: {} train Loss: {} train_r: {}".format(i, reg_loss, train_r[-1]))

            val_loss = []
            val_r = []
            for i, batch in enumerate(val_dataloader) :
                with torch.no_grad() :
                    output, target = self.run_model(batch)

                    reg_loss = mse(output, target)
                    val_loss.append(reg_loss.item())
                    val_r.append(self.get_r(output, target))
                    if i%10 == 0 :
                        print("Batch: {} val Loss: {} val_r: {}".format(i, reg_loss, val_r[-1]))

            self.training_stats.append({
                'training loss' : sum(train_loss)/len(train_loss),
                'validation loss' : sum(val_loss)/len(val_loss),
                'train r' : torch.tensor(train_r).float().mean(dim=0),
                'val r' : torch.tensor(val_r).float().mean(dim=0),
            })
            print(json.dumps(self.training_stats[-1], indent=4))
            
        def save_model(self) :
            #TODO: Add function to save model
            pass
        
        def plot(self, r_values) :
            #TODO: Plot r values
            pass

In [6]:
train = Train()
train.train(10)

Batch: 0 train Loss: 8.834592819213867 train_r: [0, 0]
Batch: 10 train Loss: nan train_r: [0, 0]
Batch: 20 train Loss: nan train_r: [0, 0]
Batch: 30 train Loss: nan train_r: [0, 0]


KeyboardInterrupt: 