In [1]:
import torch
import random
import torch.nn.functional as F


In [2]:
from convokit import Corpus, download
# Fetch the "conversations-gone-awry-corpus" dataset via ConvoKit's download helper
# and load it as a Corpus object; every later cell reads from `corpus`.
corpus = Corpus(filename=download("conversations-gone-awry-corpus"))

Dataset already exists at /Volumes/Users/tran_s2/.convokit/downloads/conversations-gone-awry-corpus


In [3]:
def loadPairs(corpus, split=None, last_only=False):
    """
    Load context-reply pairs from the Corpus, optionally filtering to only conversations
    from the specified split (train, val, or test).

    Each conversation, which has N comments (not including the section header) will
    get converted into N-1 comment-reply pairs, one pair for each reply
    (the first comment does not reply to anything). Conversations with fewer than
    two comments therefore contribute no pairs.

    INPUT:
        corpus: object exposing iter_conversations(); each conversation exposes
            meta['split'] and iter_utterances(); each utterance exposes text, id,
            meta['is_section_header'] and meta['comment_has_personal_attack'].
        split: str, optional
            If specified, only conversations whose meta['split'] equals this value are used.
        last_only: bool, optional
            If True, emit only the pair whose reply is the last comment of each conversation.
    OUTPUT:
        pairs: list of tuple
            Each tuple is (context, reply, label, comment_id) where context is the list of
            texts of all comments prior to the reply, reply is the reply text, label is
            1 if the reply contains a personal attack (else 0), and comment_id is the
            id of the reply (for later use in re-joining with the ConvoKit corpus).
    """
    pairs = []
    for convo in corpus.iter_conversations():
        # consider only conversations in the specified split of the data
        if split is not None and convo.meta['split'] != split:
            continue

        # collect the real comments, dropping section headers
        utterance_list = [
            {"text": utterance.text,
             "is_attack": int(utterance.meta['comment_has_personal_attack']),
             "id": utterance.id}
            for utterance in convo.iter_utterances()
            if not utterance.meta['is_section_header']
        ]

        # A reply needs at least one preceding comment as context. Without this guard,
        # last_only=True would index an empty list or emit a pair with an empty context.
        if len(utterance_list) < 2:
            continue

        last_idx = len(utterance_list) - 1
        iter_range = [last_idx] if last_only else range(1, len(utterance_list))
        for idx in iter_range:
            target = utterance_list[idx]
            # gather as context all utterances preceding the reply
            context = [u["text"] for u in utterance_list[:idx]]
            pairs.append((context, target["text"], target["is_attack"], target["id"]))

    return pairs
def conversations2utterances(conversations):
    """
    Flatten a batch of conversations into one utterance list for UtteranceModel.

    INPUT:
        conversations: list of list of str
            One inner list of utterance strings per conversation.
    OUTPUT:
        utterances: list of str
            Every utterance of every conversation, concatenated in input order.
        conversationLength: list of int
            Number of utterances in each conversation, in input order, so the flat
            list can later be split back into conversations.
    """
    # lengths are taken first, then the conversations are flattened in a second pass
    conversationLength = list(map(len, conversations))
    utterances = [utterance for convo in conversations for utterance in convo]
    return utterances, conversationLength
def load_data(corpus, context_batch_size = 32, split=None, last_only=False, shuffle=True):
    """
    Load data from corpus and group it into batches ready for UtteranceModel.

    INPUT:
        corpus: convokit.Corpus
        context_batch_size: int, optional
            Number of contexts (one per reply) per batch; the final batch may be smaller.
        split: str, optional
            If specified, only consider conversations in the specified split of the data.
        last_only: bool, optional
            If True, only consider the last utterance in each conversation as the reply.
        shuffle: bool, optional
            If True, the context-reply pairs are shuffled with random.shuffle before batching.
    OUTPUT (all four lists have one entry per batch, in the same order):
        batch_utterances: list of list of str
            For each batch, the utterances of all its contexts flattened into one list.
        batch_conversationLength: list of list of int
            For each batch, the number of utterances in each context.
        batch_comment_ids: list of list
            For each batch, the ids of the reply utterances.
        batch_labels: list of list of int
            For each batch, the label of each context (whether the reply contains a
            personal attack).
    RAISES:
        ValueError: if context_batch_size is smaller than 1.
    """
    if context_batch_size < 1:
        raise ValueError("context_batch_size must be at least 1, got {}".format(context_batch_size))

    pairs = loadPairs(corpus, split, last_only)
    if shuffle:
        random.shuffle(pairs)

    batch_utterances = []
    batch_conversationLength = []
    batch_comment_ids = []
    batch_labels = []
    # walk over the pairs in consecutive chunks of context_batch_size
    for start in range(0, len(pairs), context_batch_size):
        chunk = pairs[start:start + context_batch_size]
        conversations = [context for context, _, _, _ in chunk]
        utterances, conversationLength = conversations2utterances(conversations)
        batch_utterances.append(utterances)
        batch_conversationLength.append(conversationLength)
        batch_labels.append([label for _, _, label, _ in chunk])
        batch_comment_ids.append([comment_id for _, _, _, comment_id in chunk])
    return batch_utterances, batch_conversationLength, batch_comment_ids, batch_labels

In [4]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# RoBERTa utterance encoder

In [5]:
from transformers import RobertaPreTrainedModel
from transformers import RobertaTokenizer
from transformers import RobertaModel
from torch import nn
class RoBERTaForUtterance(RobertaPreTrainedModel):
    """
    Utterance encoder: maps each utterance string to the pooled output of a
    pretrained "roberta-base" model.

    The module owns both the RoBERTa model and its tokenizer, so callers pass raw strings.
    """
    def __init__(self, config, device, batch_size=4):
        """
        INPUT:
            config: configuration object forwarded to RobertaPreTrainedModel.
            device: device the RoBERTa weights are moved to.
            batch_size: int, optional
                Number of utterances tokenized and encoded per forward mini-batch.
        """
        super(RoBERTaForUtterance, self).__init__(config)
        self.batch_size = batch_size
        self.roberta = RobertaModel.from_pretrained("roberta-base").to(device)
        self.tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

    def tokenize(self, utterances, conversationLength):
        """
        Lazily tokenize `utterances` in mini-batches of `self.batch_size`.

        INPUT:
            utterances: list of str
            conversationLength: unused; kept so the signature mirrors forward().
        OUTPUT:
            generator yielding one tokenizer output per mini-batch, padded/truncated
            to 256 tokens and returned as PyTorch tensors.
        """
        def batch_tokenize(utterances, batch_size=8):
            for start in range(0, len(utterances), batch_size):
                yield self.tokenizer(utterances[start:start + batch_size],
                                     padding="max_length", truncation=True,
                                     max_length=256, return_tensors="pt")
        return batch_tokenize(utterances, batch_size=self.batch_size)

    def forward(self, utterances, conversationLength):
        """
        Encode every utterance and return one vector per utterance.

        INPUT:
            utterances: list of str
            conversationLength: forwarded to tokenize(), which does not use it.
        OUTPUT:
            torch.Tensor on the CPU with one row per utterance (RoBERTa's pooler_output).

        NOTE(review): the result is detached from the autograd graph, so no gradient
        reaches the RoBERTa weights through this output; callers that want to
        fine-tune the encoder need a non-detached path.
        NOTE(review): loading "roberta-base" reports the pooler weights as newly
        initialized, so pooler_output goes through an untrained layer; consider
        using last_hidden_state[:, 0] instead - confirm before changing.
        """
        # Use the device the weights actually live on rather than a module-level global,
        # so the inputs always match the `device` given to __init__.
        model_device = next(self.roberta.parameters()).device
        cls_vectors = []
        for tokens in self.tokenize(utterances, conversationLength):
            outputs = self.roberta(tokens.input_ids.to(model_device),
                                   attention_mask=tokens.attention_mask.to(model_device))
            cls_vectors.append(outputs.pooler_output.cpu().detach())
        if not cls_vectors:
            # torch.cat raises on an empty list; return an empty [0, hidden] tensor instead
            return torch.empty((0, self.roberta.config.hidden_size))
        return torch.cat(cls_vectors, dim=0)
        



# Context RNN

In [6]:
class ContextEncoderRNN(nn.Module):
    """
    Context encoder component of CRAFT: a unidirectional GRU that turns a sequence of
    utterance vectors into an order-sensitive representation of the conversation context.
    """
    def __init__(self, hidden_size, n_layers=1, dropout=0):
        """
        INPUT:
            hidden_size: int
                Size of both the input utterance vectors and the GRU hidden state.
            n_layers: int, optional
                Number of stacked GRU layers.
            dropout: float, optional
                Dropout handed to nn.GRU; replaced by 0 when n_layers == 1.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers

        # a single-layer GRU is always built with dropout 0
        gru_dropout = dropout if n_layers != 1 else 0
        self.gru = nn.GRU(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=n_layers,
            batch_first=True,
            dropout=gru_dropout,
            bidirectional=False,
        )

    def forward(self, input_seq, hidden=None):
        """
        Run the GRU over `input_seq` (batch-first) starting from `hidden`.

        OUTPUT:
            (outputs, hidden): the per-step GRU outputs and the final hidden state,
            exactly as returned by nn.GRU.
        """
        return self.gru(input_seq, hidden)

class SingleTargetClf(nn.Module):
    """
    CRAFT classifier head: takes the context encoding and turns it into one forecast
    logit per conversation through a small feed-forward network with dropout.
    """
    def __init__(self, hidden_size, dropout=0.1):
        """
        INPUT:
            hidden_size: int
                Size of the context encoding fed to the first linear layer.
            dropout: float, optional
                Dropout probability applied before every linear layer.
        """
        super(SingleTargetClf, self).__init__()

        self.hidden_size = hidden_size

        # initialize classifier
        self.layer1 = nn.Linear(hidden_size, hidden_size)
        self.layer1_act = nn.LeakyReLU()
        self.layer2 = nn.Linear(hidden_size, hidden_size // 2)
        self.layer2_act = nn.LeakyReLU()
        self.clf = nn.Linear(hidden_size // 2, 1)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, encoder_hidden):
        """
        INPUT:
            encoder_hidden: tensor indexed as [layer, batch, hidden]; only the last
                layer's hidden state is used.
        OUTPUT:
            logits: tensor of shape [batch], one raw (pre-sigmoid) score per conversation.
        """
        # Selecting the last layer already yields [batch, hidden]. A dimensionless
        # squeeze() here would also drop the batch axis when the batch has one element.
        hidden = encoder_hidden[-1]
        # forward pass through hidden layers
        layer1_out = self.layer1_act(self.layer1(self.dropout(hidden)))
        layer2_out = self.layer2_act(self.layer2(self.dropout(layer1_out)))
        # compute and return logits; squeeze only the trailing size-1 feature axis so a
        # batch of one still produces shape [1], matching its label tensor
        logits = self.clf(self.dropout(layer2_out)).squeeze(-1)
        return logits


In [7]:
class HierAtt(nn.Module):
    """
    Attention-pooling classifier: scores each utterance embedding, softmax-normalizes
    the scores over the sequence (ignoring masked positions), and classifies the
    attention-weighted sum of the embeddings.
    """
    def __init__(self, utt_emb_size) -> None:
        """
        INPUT:
            utt_emb_size: int
                Size of each utterance embedding.
        """
        super().__init__()
        self.attention1 = nn.Linear(utt_emb_size, utt_emb_size)
        self.attention2 = nn.Linear(utt_emb_size, 1, bias=False)
        self.clf = nn.Linear(utt_emb_size, 1)

    def forward(self, utt_emb, mask):
        """
        INPUT:
            utt_emb: tensor [batch_size, seq_length, utt_emb_size]
            mask: tensor [batch_size, seq_length]; positions equal to 0 are ignored.
        OUTPUT:
            tensor [batch_size, 1] of logits.
        """
        conversation_embedding = self.attention_net(utt_emb, self.attention1, self.attention2, mask)
        final_output = self.clf(conversation_embedding)
        return final_output

    def attention_net(self, utt_emb, attention_net1, attention_net2 , mask):
        """
        Pool `utt_emb` over the sequence axis with learned attention weights.

        INPUT:
            utt_emb: tensor [batch_size, seq_length, utt_emb_size]
            attention_net1, attention_net2: the two linear layers producing the scores.
            mask: tensor [batch_size, seq_length]; positions equal to 0 get a score of
                -1e15 before the softmax.
        OUTPUT:
            tensor [batch_size, utt_emb_size], the attention-weighted sum of embeddings.
        """
        # The former debug prints called get_device() on an nn.Linear, which has no
        # such method, so every forward pass raised AttributeError; they are removed.
        hidden_re = torch.tanh(attention_net1(utt_emb)) # [batch_size, seq_length, utt_emb_size]
        attn_weights = attention_net2(hidden_re).squeeze(2) # [batch_size, seq_length]
        attn_weights = attn_weights.masked_fill(mask==0, -1e15)
        soft_attn_weights = F.softmax(attn_weights, 1) # [batch_size, seq_length]
        # [batch_size, utt_emb_size, seq_length] * [batch_size, seq_length, 1] = [batch_size, utt_emb_size, 1]
        final_embedding = torch.bmm(utt_emb.transpose(1, 2), soft_attn_weights.unsqueeze(2)).squeeze(2)
        return final_embedding

# Pipeline

In [8]:
# Build the batched train/validation data. last_only=True keeps one example per
# conversation (the last comment is the reply to forecast); each batch holds up to 64 contexts.
# shuffle is left at its default (True), so batch composition depends on the `random` state.
train_utterances, train_conversationLength, train_comment_ids, train_labels = load_data(corpus, split='train', last_only=True, context_batch_size=64)
valid_utterances, valid_conversationLength, valid_comment_ids, valid_labels = load_data(corpus, split='val', last_only=True, context_batch_size=64)

In [9]:
# Sanity check: number of examples in the first validation batch.
len(valid_labels[0])

64

In [10]:
# Sanity check: fraction of positive (attack) labels in the first validation batch.
sum(valid_labels[0])/len(valid_labels[0])

0.453125

In [11]:
from transformers import AutoConfig
config = AutoConfig.from_pretrained('roberta-base')
# Place every model on the shared `device` instead of a hard-coded 'cuda', so the
# notebook also runs on machines without a GPU (where `device` falls back to the CPU).
my_roberta = RoBERTaForUtterance(config, device, batch_size=16)
# 768 is the hidden size passed to both the context GRU and the classifier head.
context_encoder = ContextEncoderRNN(768, 1, 0.1)
attack_clf = SingleTargetClf(768, 0.1)
context_encoder = context_encoder.to(device)
attack_clf = attack_clf.to(device)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
import numpy as np
def prepare_context_batch(utt_hidden, batch_conversationLength, max_context_len=20):
    """
    Regroup a flat matrix of utterance vectors into a fixed-size, left-padded batch
    of conversation contexts.

    INPUT:
        utt_hidden: 2-D array-like [total_utterances, hidden_size]
            Utterance vectors of all conversations in the batch, concatenated in order.
            Must be convertible with np.asarray (e.g. a CPU tensor without grad).
        batch_conversationLength: list of int
            Number of utterances of each conversation; must sum to total_utterances.
        max_context_len: int, optional
            Number of time steps per conversation in the output. Longer conversations
            keep only their last max_context_len utterances.
    OUTPUT:
        context_features: np.ndarray float32 [num_conversations, max_context_len, hidden_size]
            Each conversation is right-aligned; leading unused steps stay zero.
            A conversation of length 0 yields an all-zero row.
    """
    assert utt_hidden.shape[0] == sum(batch_conversationLength)
    # convert once up front instead of once per conversation inside the loop
    utt_array = np.asarray(utt_hidden)
    hidden_size = utt_array.shape[1]
    context_features = np.zeros((len(batch_conversationLength), max_context_len, hidden_size), dtype=np.float32)

    convo_start = 0
    for i, convo_len in enumerate(batch_conversationLength):
        convo_end = convo_start + convo_len
        # keep at most the last max_context_len utterances of this conversation
        kept = min(convo_len, max_context_len)
        # guard kept == 0: a "-0:" slice would select the whole row and fail to broadcast
        if kept > 0:
            context_features[i, max_context_len - kept:, :] = utt_array[convo_end - kept:convo_end, :]
        convo_start = convo_end
    return context_features

In [13]:
from torch import optim
# One learning rate shared by all three optimizers below.
learning_rate = 3e-5
# Binary cross-entropy on raw logits, with the default reduction.
criterion = nn.BCEWithLogitsLoss()
# NOTE(review): RoBERTaForUtterance.forward returns a detached tensor, so these
# parameters receive no gradient from the training loss and this optimizer's
# step() leaves them unchanged.
encoder_optimizer = optim.AdamW(my_roberta.parameters(),
                  lr = learning_rate, # shared learning rate defined above
                  eps = 1e-8 # AdamW epsilon, set explicitly
                )
context_encoder_optimizer = optim.AdamW(context_encoder.parameters(), lr=learning_rate)
attack_clf_optimizer = optim.AdamW(attack_clf.parameters(), lr=learning_rate)

In [14]:
# Peek at the first training batch: print how many utterances its contexts contain in total.
batch_utterances = train_utterances[0]
batch_conversationLength = train_conversationLength[0]
print(len(batch_utterances))

347


In [15]:
def calculate_f1_score(labels, preds):
    """
    Compute the F1 score of the positive class (value 1) from element-wise comparisons.

    INPUT:
        labels: tensor of gold labels (0 or 1)
        preds: tensor of predictions (0/1 or boolean), same shape as labels
    OUTPUT:
        f1: the harmonic mean of precision and recall, or 0 when either is undefined
            or both are zero.
    """
    predicted_pos = preds == 1
    actual_pos = labels == 1

    true_pos = (predicted_pos & actual_pos).sum().item()
    false_pos = (predicted_pos & (labels == 0)).sum().item()
    false_neg = ((preds == 0) & actual_pos).sum().item()

    # precision / recall default to 0 when their denominator is empty
    if true_pos + false_pos > 0:
        precision = true_pos / (true_pos + false_pos)
    else:
        precision = 0
    if true_pos + false_neg > 0:
        recall = true_pos / (true_pos + false_neg)
    else:
        recall = 0

    if (precision + recall) > 0:
        return 2 * (precision * recall) / (precision + recall)
    return 0
def evaluate(encoder, context_encoder, attack_clf, val_utterances, val_conversationLength, val_labels):
    """
    Evaluate the utterance encoder -> context encoder -> classifier pipeline on
    pre-batched validation data, without tracking gradients.

    All three modules are switched to eval mode and are NOT switched back.
    Tensors are moved to the module-level `device`.

    INPUT:
        encoder, context_encoder, attack_clf: the three pipeline modules.
        val_utterances: list of list of str, flattened utterances per batch.
        val_conversationLength: list of list of int, context lengths per batch.
        val_labels: list of list of int, labels per batch.
    OUTPUT:
        val_loss: sum over batches of the per-batch mean binary cross-entropy.
        val_f1: F1 score of the positive class over all examples.
        val_accuracy: fraction of examples whose thresholded prediction equals the label.
        pos: fraction of examples predicted positive (sigmoid(logit) > 0.5).
        All four are 0 when there is nothing to evaluate.
    """
    encoder.eval()
    context_encoder.eval()
    attack_clf.eval()
    val_loss = 0
    num_sample = 0
    all_labels = []
    all_preds = []
    with torch.no_grad():
        for batch_utterances, batch_conversationLength, batch_labels in zip(
                val_utterances, val_conversationLength, val_labels):
            if len(batch_labels) == 0:
                continue
            num_sample += len(batch_labels)
            utt_hidden = encoder(batch_utterances, batch_conversationLength)
            context_features = prepare_context_batch(utt_hidden, batch_conversationLength)
            context_features = torch.from_numpy(context_features).to(device)
            context_outputs, context_hidden = context_encoder(context_features)
            logits = attack_clf(context_hidden)
            labels = torch.tensor(batch_labels, dtype=torch.float32).to(device)
            loss = F.binary_cross_entropy_with_logits(logits, labels)
            val_loss += loss.item()
            preds = torch.sigmoid(logits) > 0.5
            all_labels.append(labels.cpu().detach())
            all_preds.append(preds.cpu().detach())

    # No (non-empty) batches: torch.cat would raise on an empty list and the ratios
    # below would divide by zero, so report zeros instead.
    if num_sample == 0:
        return val_loss, 0, 0, 0

    all_labels = torch.cat(all_labels, dim=0)
    all_preds = torch.cat(all_preds, dim=0)
    assert all_labels.shape == all_preds.shape
    val_accuracy = (all_labels == all_preds).sum().item()/num_sample
    pos = all_preds.sum().item()/num_sample

    val_f1 = calculate_f1_score(all_labels, all_preds)
    return val_loss, val_f1, val_accuracy, pos

In [16]:
# Training loop: 50 passes over the pre-batched training data, with a validation
# report every 200 optimization steps.
import logging
# Suppress all log records of severity WARNING and below for the rest of the session.
logging.disable(logging.WARNING)
num_steps = 0
for epoch in range(50):
    for batch_idx in range(len(train_labels)):
        num_steps += 1
        # Evaluate BEFORE the update of this step, so the report at step k reflects k-1 updates.
        if num_steps % 200 == 0:
            print(num_steps, epoch)
            val_loss, val_f1, val_accuracy, pos= evaluate(my_roberta, context_encoder, attack_clf, valid_utterances, valid_conversationLength, valid_labels)
            print("Validation loss: {:.2f} accuracy: {:.2f} f1: {:.2f} pos: {:.2f}".format(val_loss, val_accuracy * 100, val_f1 * 100, pos))

        # evaluate() leaves the modules in eval mode, so re-enable training mode every step
        my_roberta.train()
        context_encoder.train()
        attack_clf.train()
        
        batch_utterances = train_utterances[batch_idx]
        batch_conversationLength = train_conversationLength[batch_idx]
        batch_comment_ids = train_comment_ids[batch_idx]
        batch_labels = train_labels[batch_idx]
        encoder_optimizer.zero_grad()
        context_encoder_optimizer.zero_grad()
        attack_clf_optimizer.zero_grad()
        
        # NOTE(review): my_roberta.forward returns a detached CPU tensor and
        # prepare_context_batch round-trips it through NumPy, so the autograd graph
        # starts at `context_features`. loss.backward() therefore produces no gradients
        # for my_roberta; only context_encoder and attack_clf are actually trained.
        hidden = my_roberta.forward(batch_utterances, batch_conversationLength)
        context_features = torch.from_numpy(prepare_context_batch(hidden, batch_conversationLength)).to(device)
        final_outputs, final_hidden = context_encoder.forward(context_features)
        logits = attack_clf(final_hidden)
        labels = torch.tensor(batch_labels, dtype=torch.float32).to(device)

        # mean binary cross-entropy over the batch (criterion is BCEWithLogitsLoss)
        loss = criterion(logits, labels)
        loss.backward()
        clip = 50.0
        # Clip gradients: gradients are modified in place
        _ = torch.nn.utils.clip_grad_norm_(my_roberta.parameters(), clip)
        _ = torch.nn.utils.clip_grad_norm_(context_encoder.parameters(), clip)
        _ = torch.nn.utils.clip_grad_norm_(attack_clf.parameters(), clip)

        # Adjust model weights
        encoder_optimizer.step()
        context_encoder_optimizer.step()
        attack_clf_optimizer.step()

200 4
Validation loss: 9.66 accuracy: 57.14 f1: 50.55 pos: 0.37
400 9
Validation loss: 9.59 accuracy: 54.76 f1: 39.10 pos: 0.24
600 14
Validation loss: 9.56 accuracy: 57.02 f1: 52.06 pos: 0.40
800 19
Validation loss: 9.55 accuracy: 57.02 f1: 52.06 pos: 0.40
1000 24
Validation loss: 9.56 accuracy: 57.02 f1: 52.06 pos: 0.40
1200 29
Validation loss: 9.54 accuracy: 57.02 f1: 52.06 pos: 0.40
1400 34
Validation loss: 9.54 accuracy: 56.79 f1: 51.66 pos: 0.39
1600 39
Validation loss: 9.55 accuracy: 57.14 f1: 51.87 pos: 0.39
1800 44
Validation loss: 9.55 accuracy: 57.14 f1: 51.87 pos: 0.39
2000 49
Validation loss: 9.55 accuracy: 57.02 f1: 51.41 pos: 0.38
