In [None]:
!pip install transformers==4.28.1
!pip install wandb

In [None]:
import logging
import random
import numpy as np
import torch
from torch import nn
from transformers import get_linear_schedule_with_warmup, AdamW, PreTrainedTokenizer
from collections import namedtuple
import json
from torch.utils.data import DataLoader, WeightedRandomSampler
#from model import BARTVAEClassifier, BARTDecoderClassifier, BARTVADVAEClassifier, RobertaClassifier
#from utils import ErcTextDataset, get_num_classes, get_label_VAD, convert_label_to_VAD, save_latent_params, compute_VAD_pearson_correlation, replace_for_robust_eval
import os
import math
import argparse
import yaml
from torch.nn.utils.rnn import pad_sequence
from sklearn.metrics import f1_score, confusion_matrix, accuracy_score, classification_report, \
    precision_recall_fscore_support, precision_score, recall_score
import torch.cuda.amp.grad_scaler as grad_scaler
import torch.cuda.amp.autocast_mode as autocast_mode
from transformers.models.bart.modeling_bart import BartModel, BartDecoder
import copy
import torch.nn.functional as F


from typing import List

from torch.utils.data import Dataset
from tqdm import tqdm
from torch.nn.functional import softplus

from transformers import RobertaTokenizer, RobertaModel, RobertaConfig, RobertaForMaskedLM, AutoModel, AutoTokenizer, AutoConfig, BartTokenizer, BartConfig,RobertaForSequenceClassification, BartForSequenceClassification
from transformers.models.bart.modeling_bart import BartLearnedPositionalEmbedding, BartEncoderLayer, BartDecoderLayer

In [None]:
# Run on the first GPU when one is available, otherwise on the CPU.
# To force CPU execution: device = torch.device('cpu')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
torch.cuda.empty_cache()

# Seed every RNG the notebook imports (python `random`, numpy, torch on all devices)
# so that a Restart-&-Run-All reproduces the same run.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


class EncodedDataset(Dataset):
    """Map-style dataset that tokenises each sentence for the encoder and for the decoder.

    ``__getitem__`` returns the 6-tuple
    ``(input_ids, mask_ids, target_labels, labels, dec_input_ids, dec_mask_ids)``:

    * ``input_ids`` / ``mask_ids``: encoder token ids and attention mask, padded to
      ``max_sequence_length`` (``padding='max_length'``) and truncated.
    * ``target_labels``: all-zero float64 placeholder of length ``max_targets``. The target
      one-hot encoding is disabled, so the ``target_labels`` passed to the constructor are
      stored but not used when building items.
    * ``labels``: one-hot float64 vector of length ``num_classes`` for the hate label.
    * ``dec_input_ids`` / ``dec_mask_ids``: decoder token ids and attention mask, same padding rules.
    """

    def __init__(self, input_sents: List[str],
                 input_labels: List[int],
                 target_labels: List[int],
                 encoder_tokenizer: "PreTrainedTokenizer",
                 decoder_tokenizer: "PreTrainedTokenizer",
                 max_sequence_length: int = None,
                 max_targets: int = 8,
                 num_classes: int = 2):
        self.input_sents = input_sents
        self.input_labels = input_labels
        self.target_labels = target_labels
        self.encoder_tokenizer = encoder_tokenizer
        self.decoder_tokenizer = decoder_tokenizer
        self.max_sequence_length = max_sequence_length
        self.max_targets = max_targets
        self.num_classes = num_classes

        # padding='max_length' needs a pad token. Set it once here rather than per item:
        # a mutation made inside a DataLoader worker process never reaches the main-process tokenizer.
        # NOTE(review): adding '[PAD]' grows the tokenizer vocabulary and no visible code resizes the
        # decoder embeddings -- confirm the new id stays within the decoder's vocabulary.
        if self.decoder_tokenizer.pad_token is None:
            self.decoder_tokenizer.add_special_tokens({'pad_token': '[PAD]'})

    def __len__(self):
        return len(self.input_sents)

    def _encode(self, tokenizer, text):
        """Tokenise ``text`` padded/truncated to ``max_sequence_length``; return (ids, mask) tensors."""
        token = tokenizer(text, padding='max_length', max_length=self.max_sequence_length, truncation=True)
        return torch.tensor(token['input_ids']), torch.tensor(token['attention_mask'])

    def __getitem__(self, index):
        text = self.input_sents[index]

        # One-hot hate label (float64, as produced by np.zeros).
        labels = np.zeros(self.num_classes)
        labels[self.input_labels[index]] = 1
        labels = torch.tensor(labels)

        # Target supervision is disabled: a zero vector keeps the item layout stable.
        target_labels = torch.tensor(np.zeros(self.max_targets))

        input_ids, mask_ids = self._encode(self.encoder_tokenizer, text)
        dec_input_ids, dec_mask_ids = self._encode(self.decoder_tokenizer, text)

        return input_ids, mask_ids, target_labels, labels, dec_input_ids, dec_mask_ids

In [None]:
class _Inference(nn.Sequential):
    def __init__(self, num_input_channels, latent_dim, disc_variable=True):
        super(_Inference, self).__init__()
        if disc_variable:
            self.add_module('fc', nn.Linear(num_input_channels, num_input_channels//2))
            self.add_module('relu', nn.ReLU())
            self.add_module('fc2', nn.Linear(num_input_channels//2, latent_dim))
            self.add_module('log_softmax', nn.LogSoftmax(dim=1))
        else:
            self.add_module('fc', nn.Linear(num_input_channels, latent_dim))

class Sample(nn.Module):
    """Draws reparameterised latent samples.

    ``forward`` concatenates a Gaussian sample (from ``norm_mean`` / ``norm_log_sigma``) with a
    caller-provided ``target_sample`` along dim 1 and returns it reshaped to ``(N, D, 1, 1)``.
    ``_sample_gumbel_softmax`` provides a relaxed discrete sample but is not used by ``forward``.
    """

    def __init__(self, temperature):
        super(Sample, self).__init__()
        # Divides the Gumbel logits; may be a float or a (learnable) tensor.
        self._temperature = temperature

    def forward(self, norm_mean, norm_log_sigma, target_sample, disc_label=None, mixup=False, disc_label_mixup=None,
                mixup_lam=None):
        """
        :param norm_mean: mean of the continuous Gaussian latent, shape (N, D_c)
        :param norm_log_sigma: log sigma of the continuous Gaussian latent, shape (N, D_c)
        :param target_sample: target latent concatenated after the Gaussian sample, shape (N, D_t)
        :param disc_label: accepted for API compatibility; unused
        :param mixup: accepted for API compatibility; unused
        :param disc_label_mixup: accepted for API compatibility; unused
        :param mixup_lam: accepted for API compatibility; unused
        :return: sampled latent variable of shape (N, D_c + D_t, 1, 1)
        """
        batch_size = norm_mean.size(0)
        latent_sample = torch.cat([self._sample_norm(norm_mean, norm_log_sigma), target_sample], dim=1)
        return latent_sample.view(batch_size, latent_sample.size(1), 1, 1)

    def _sample_gumbel_softmax(self, log_alpha):
        """
        Samples from a gumbel-softmax distribution using the reparameterization
        trick.

        Parameters
        ----------
        log_alpha : torch.Tensor
            Parameters of the gumbel-softmax distribution. Shape (N, D)

        Returns
        -------
        torch.Tensor
            Relaxed one-hot sample of shape (N, D); each row sums to 1.
        """
        EPS = 1e-12
        # Draw the uniform noise on log_alpha's own device so this works on CPU as well as GPU.
        unif = torch.rand(log_alpha.size(), device=log_alpha.device)
        gumbel = -torch.log(-torch.log(unif + EPS) + EPS)
        # Reparameterize to create gumbel softmax sample
        logit = (log_alpha + gumbel) / self._temperature
        return torch.softmax(logit, dim=1)

    @staticmethod
    def _sample_norm(mu, log_sigma):
        """Reparameterised draw ``z = mu + exp(min(log_sigma, 10)) * eps`` with ``eps ~ N(0, I)``.

        ``log_sigma`` is clamped at 10 before exponentiation to prevent overflow. The noise is drawn
        directly on ``mu``'s device (float32), which avoids a host-to-device copy on every call.
        """
        std_z = torch.randn(mu.size(), device=mu.device)
        sigma = torch.exp(torch.clamp(log_sigma, max=10))
        return mu + sigma * std_z


class HATE_WATCH(nn.Module):
    """The VAD-VAE model, used here to separate a hate-related latent from a target-related latent.

    ``forward`` pipeline:
      1. The RoBERTa encoder (weights frozen) embeds the input. The first-token vector is
         batch-normalised and then standardised per sample.
      2. A Gaussian latent (mean / log-sigma heads) and a target embedding (``target_mapper``) are
         concatenated by ``Sample``. The result is compressed back to ``hidden_size`` and used as
         the decoder's conditioning state (BART cross-attention or the LSTM's initial hidden state)
         to reconstruct the text.
      3. The losses are:
         - token reconstruction;
         - |KL - capacity| terms for both latents;
         - a pseudo-label contrastive loss on the target embedding;
         - a confidence regulariser.
      4. ``hate_classifier`` is applied to a fresh sample of the Gaussian latent.

    Unless ``classifier_dropout`` is given, the constructor reads ``params1.classifier_dropout``
    from the notebook's global scope.
    """
    def __init__(self, temperature, gamma, eta, enc_chk, check_point, dec_chk, bart_check_point, num_targets,
                 num_class, beta_c, beta_d, device, batch_size, decoder_type, x_sigma=1, classifier_dropout=None):
        """Build the encoder, decoder and all heads.

        Args:
            temperature: initial (learnable) temperature handed to ``Sample``.
            gamma: initial (learnable) margin of the contrastive loss.
            eta: initial (learnable) threshold applied to the pairwise contrastive weights.
            enc_chk: path to a fine-tuned ``RobertaForSequenceClassification`` state dict, or "" to use
                the plain pretrained ``check_point`` encoder.
            check_point: name/path of the RoBERTa model (encoder, tokenizer and config).
            dec_chk: unused (decoder checkpoint loading is disabled).
            bart_check_point: name/path of the BART model whose decoder is used.
            num_targets: number of target categories (size of the uniform discrete prior).
            num_class: number of hate classes produced by ``hate_classifier``.
            beta_c: weight of the continuous KL capacity term.
            beta_d: weight of the discrete KL capacity term.
            device: device for the pretrained models and the auxiliary tensors.
            batch_size: nominal batch size (stored; the losses use the actual batch size).
            decoder_type: 'BART' for a BART decoder, anything else for a single-layer LSTM.
            x_sigma: stored but unused by the current loss.
            classifier_dropout: dropout probability of the heads; ``None`` falls back to the
                notebook-global ``params1.classifier_dropout``.
        """
        super(HATE_WATCH, self).__init__()
        self.device = device
        self.batch_size = batch_size
        self.decoder_type = decoder_type
        self.temperature = nn.Parameter(torch.tensor(temperature, requires_grad=True, device=device))

        # Prepare the encoder and decoder for the VAE.
        if enc_chk != "":
            self.encoder = RobertaForSequenceClassification.from_pretrained(check_point).to(device)
            # map_location lets a GPU-saved checkpoint load on a CPU-only machine as well.
            self.encoder.load_state_dict(torch.load(enc_chk, map_location=device))
            self.encoder = self.encoder.roberta
        else:
            self.encoder = RobertaModel.from_pretrained(check_point).to(device)
        # Freeze the pretrained encoder weights.
        for param in self.encoder.base_model.parameters():
            param.requires_grad = False
        self.tokenizer = RobertaTokenizer.from_pretrained(check_point)
        self.config = AutoConfig.from_pretrained(check_point)

        # Learnable KL capacities (10-vectors); forward() returns element 0 of the resulting loss vector.
        self.cmi = torch.nn.Parameter(torch.rand(10, requires_grad=True, device=device))
        self.dmi = torch.nn.Parameter(torch.rand(10, requires_grad=True, device=device))

        hidden_size = self.config.hidden_size
        if decoder_type == 'BART':
            self.decoder = BartDecoder.from_pretrained(bart_check_point).to(device)
        else:
            self.decoder = nn.LSTM(hidden_size, hidden_size, 1, batch_first=True)

        self.decoder_start_token_id = 2
        self.lm_head = nn.Linear(hidden_size, self.config.vocab_size)
        self.lm_loss_fn = nn.CrossEntropyLoss(reduction='none')
        self.x_sigma = x_sigma
        self.relu = nn.ReLU()
        self.gamma = nn.Parameter(torch.tensor(gamma, requires_grad=True, device=device))
        self.eta = nn.Parameter(torch.tensor(eta, requires_grad=True, device=device))
        self.compressor = nn.Linear(2 * hidden_size, hidden_size)

        dropout = params1.classifier_dropout if classifier_dropout is None else classifier_dropout
        self.dropout_rate = dropout

        # Continuous (Gaussian) latent: separate linear heads for the mean and the log-sigma.
        self.continuous_inference = nn.Sequential()
        self.continuous_inference.add_module(
            "mean", _Inference(num_input_channels=hidden_size, latent_dim=hidden_size, disc_variable=False))
        self.continuous_inference.add_module(
            "log_sigma", _Inference(num_input_channels=hidden_size, latent_dim=hidden_size, disc_variable=False))

        self._disc_latent_dim = num_targets
        self.sample = Sample(temperature=self.temperature)

        self.kl_beta_c = beta_c
        self.kl_beta_d = beta_d

        # Log of a uniform prior over the target categories, shape (1, num_targets). Placed on `device`
        # instead of a hard-coded .cuda() so the model also builds on CPU.
        self.disc_log_prior_param = torch.log(
            torch.tensor([1 / self._disc_latent_dim for _ in range(self._disc_latent_dim)])
            .view(1, -1).float().to(device))

        # Reconstructor
        self.reconstructor = nn.Linear(hidden_size, hidden_size)
        self.batch_norm = nn.BatchNorm1d(num_features=hidden_size)
        # Hate classifier.
        # NOTE(review): ends in Softmax, yet train() feeds this output to nn.CrossEntropyLoss, which
        # applies log-softmax itself (double softmax). Argmax predictions are unaffected.
        self.hate_classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_class),
            nn.Softmax(dim=1)
        )

        # Target mapper: produces the target embedding that is concatenated into the latent.
        self.target_mapper = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size)
        )

        # Target classifier: outputs hidden_size tanh activations (the num_targets projection is disabled).
        self.target_classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
        )
        self.mse_loss = nn.MSELoss(reduction='sum')
        self.softmax = nn.Softmax(dim=1)

    def compute_sample_weight(self, y_pred):
        """Per-row confidence ``1 - H(p) / log(C)``: 1 for a one-hot row, 0 for a uniform row.

        ``y_pred`` is expected to hold probabilities of shape (N, C). Not used by ``forward``.
        """
        y_pred = y_pred.float()
        C = y_pred.shape[1]

        y_pred_clamped = torch.clamp(y_pred, 1e-9, 1.0)

        entropy = -torch.sum(y_pred_clamped * torch.log(y_pred_clamped), dim=1)
        max_entropy = torch.log(torch.tensor(C).float())
        weight = 1 - (entropy / max_entropy)

        return weight

    def get_contrastive_loss(self, embeddings, labels):
        """Pairwise contrastive loss on ``embeddings`` using argmax pseudo-labels from ``labels``.

        Returns:
            ``(loss, weights)``, where ``weights[i, j]`` is 1.0 when samples i and j share the same
            argmax of ``softmax(labels)`` and 0.0 otherwise.
        """
        max_indices = torch.argmax(torch.softmax(labels, dim=1), dim=1)
        weights = (max_indices.unsqueeze(1) == max_indices.unsqueeze(0)).float()

        # Euclidean pairwise distance normalised by the embedding width (eps keeps sqrt differentiable at 0).
        diff = embeddings.unsqueeze(1) - embeddings.unsqueeze(0)
        distance_matrix = torch.sqrt(torch.sum(diff ** 2, dim=-1) + 1e-8) / (embeddings.shape[1] + 1e-8)

        # NOTE(review): weights is 0/1, so for 0 <= eta < 1 the mask equals weights and the
        # (1 - weights) * mask dissimilarity term is identically zero. The entropy-based
        # compute_sample_weight() looks like the intended weighting -- confirm before changing.
        mask = weights > self.eta
        contrastive_loss = (weights * mask * distance_matrix.pow(2)
                            + (1 - weights) * mask * F.relu(self.gamma - distance_matrix).pow(2))
        # Average over the selected pairs.
        loss = torch.sum(contrastive_loss) / (mask.sum() + 1e-8)
        return loss, weights

    def get_lm_loss(self, logits, labels, masks):
        """Get the utterance reconstruction loss.

        Token cross-entropy times ``masks``, averaged over *all* positions: masked positions count
        as 0 in the mean, so the value shrinks as padding grows.
        """
        loss = self.lm_loss_fn(logits.view(-1, self.config.vocab_size), labels.view(-1))
        masked_loss = loss * masks.view(-1)
        return torch.mean(masked_loss)

    def confidence_regularization(self, predictions):
        """KL(uniform || predictions) averaged over the batch; ``predictions`` are row probabilities."""
        C = predictions.shape[1]
        uniform_distribution = torch.full_like(predictions, 1.0 / C)
        return F.kl_div(predictions.log(), uniform_distribution, reduction='batchmean')

    def forward(self, inputs, mask, decoder_inputs, decoder_masks, decoder_labels, vad_labels, labels):
        """Run one VAE pass and compute the auxiliary losses.

        Args:
            inputs, mask: encoder token ids and attention mask.
            decoder_inputs, decoder_masks: decoder token ids and attention mask.
            decoder_labels: token ids the decoder is trained to reproduce (weighted by ``decoder_masks``).
            vad_labels, labels: unused; kept for call compatibility.

        Returns:
            ``(elbo_loss, hate_probs, contrastive_loss, confidence_loss)``. ``elbo_loss`` is element 0
            of the capacity-weighted ELBO vector, and ``hate_probs`` is the softmax output of
            ``hate_classifier``.
        """
        # First-token encoder state, batch-normalised, then standardised per sample.
        x = self.encoder(inputs, attention_mask=mask)[0][:, 0, :]
        x = self.batch_norm(x)
        x = (x - x.mean(dim=1, keepdim=True)) / (x.std(dim=1, keepdim=True) + 1e-6)

        if torch.isnan(x).any():
            print("NaN values detected in BatchNorm output")

        # Latent variables: Gaussian content latent plus target embedding.
        norm_mean = self.continuous_inference.mean(x)
        norm_log_sigma = self.continuous_inference.log_sigma(x)
        t_sample = self.target_mapper(x)
        latent_sample = self.sample(norm_mean, norm_log_sigma, t_sample)
        # NOTE(review): target_classifier ends in Tanh with hidden_size outputs, so t_logits are
        # neither num_targets logits nor log-probabilities, although the KL below treats them as such.
        t_logits = self.target_classifier(t_sample)
        con_loss, weights = self.get_contrastive_loss(t_sample, t_logits)

        # (B, 2H, 1, 1) -> (B, 2H). flatten(1) keeps the batch dim even when B == 1 (squeeze would not).
        latent_sample = latent_sample.flatten(1)
        latent_sample = self.compressor(latent_sample)
        decoder_hidden = self.reconstructor(latent_sample)

        if self.decoder_type == 'BART':
            # The compressed latent is the decoder's single cross-attention position.
            decoder_outputs = self.decoder(
                input_ids=decoder_inputs,
                attention_mask=decoder_masks,
                encoder_hidden_states=decoder_hidden.unsqueeze(1))
            lm_logits = self.lm_head(decoder_outputs.last_hidden_state)
        else:
            # LSTM: the latent initialises the hidden state; the cell state starts at zero.
            input_embeddings = self.encoder.embeddings(decoder_inputs)
            h = decoder_hidden.unsqueeze(0)
            decoder_outputs, _ = self.decoder(input_embeddings, (h, torch.zeros_like(h)))
            lm_logits = self.lm_head(decoder_outputs)

        # NOTE(review): labels are not shifted relative to decoder_inputs; if both share a vocabulary,
        # position t can simply copy its own input token. Confirm this is intended.
        reconstruct_loss = self.get_lm_loss(lm_logits, decoder_labels, decoder_masks)

        # Continuous KL(q(z|x) || N(0, I)), averaged over the actual batch size.
        z_mean_sq = norm_mean * norm_mean
        z_log_sigma_sq = 2 * norm_log_sigma
        z_sigma_sq = torch.exp(z_log_sigma_sq)
        continuous_kl_loss = 0.5 * torch.sum(z_mean_sq + z_sigma_sq - z_log_sigma_sq - 1) / norm_mean.size(0)

        # "Discrete" KL of exp(t_logits) against a uniform prior over num_targets categories,
        # weighted by the pairwise same-pseudo-label indicator and its mask.
        mask = weights > self.eta
        log_q = torch.log(torch.tensor(1.0 / self._disc_latent_dim)).to(t_logits.device)
        kl_div = torch.sum(torch.exp(t_logits) * (t_logits - log_q), dim=1)
        disc_kl_loss = torch.mean(weights * mask * kl_div)

        # Penalise |KL - capacity| for both latents.
        prior_kl_loss_l = (self.kl_beta_c * torch.abs(continuous_kl_loss - self.cmi)
                           + self.kl_beta_d * torch.abs(disc_kl_loss - self.dmi))
        elbo_loss_l = reconstruct_loss + prior_kl_loss_l

        confidence_loss = self.confidence_regularization(self.softmax(t_logits))

        # Hate classification on a fresh sample of the continuous latent.
        hate_sample = self.sample._sample_norm(norm_mean, norm_log_sigma)
        hate_logits = self.hate_classifier(hate_sample)

        return elbo_loss_l[0], hate_logits, con_loss, confidence_loss

In [None]:
def train(hate_model, train_data, train_labels, train_target, val_data, val_labels, val_target,
          encoder_tokenizer, decoder_tokenizer, params1):
    """Train ``hate_model`` with gradient accumulation and early stopping on validation macro-F1.

    Each epoch does the following:
      - one pass over the training set, drawn with ``WeightedRandomSampler(params1.h_weights)``;
      - evaluation on the validation set;
      - a single Weights & Biases log call;
      - a checkpoint save whenever the rounded validation macro-F1 does not decrease.

    Training stops after ``params1.patience`` (default 3) consecutive epochs without improvement.

    Args:
        hate_model: model whose forward returns
            ``(elbo_loss, hate_logits, contrastive_loss, confidence_loss)``.
        train_data, val_data: lists of sentences.
        train_labels, val_labels: integer hate labels.
        train_target, val_target: integer target labels forwarded to ``EncodedDataset``.
        encoder_tokenizer, decoder_tokenizer: tokenizers forwarded to ``EncodedDataset``.
        params1: hyper-parameter namespace. Reads ``max_sequence_length``, ``h_weights``,
            ``train_batch_size``, ``val_batch_size``, ``content_lr``, ``weight_decay``,
            ``hate_class_weights``, ``warmup_ratio``, ``num_epochs``, ``hate_coeff``, ``alpha``,
            ``target_loss_coeff``, ``dataset_name`` and ``fname``. Optionally reads ``save_dir``
            and ``patience``.

    Raises:
        ValueError: if either loader yields no batches (``drop_last=True`` with too few samples).
    """
    accumulation_steps = 16
    patience = getattr(params1, 'patience', 3)
    save_dir = getattr(params1, 'save_dir', "D:/Hate Speech/Models/Disentanglement/WeakSupervision/")

    train_dataset = EncodedDataset(input_sents=train_data,
                                   input_labels=train_labels,
                                   target_labels=train_target,
                                   encoder_tokenizer=encoder_tokenizer,
                                   decoder_tokenizer=decoder_tokenizer,
                                   max_sequence_length=params1.max_sequence_length)
    val_dataset = EncodedDataset(input_sents=val_data,
                                 input_labels=val_labels,
                                 target_labels=val_target,
                                 encoder_tokenizer=encoder_tokenizer,
                                 decoder_tokenizer=decoder_tokenizer,
                                 max_sequence_length=params1.max_sequence_length)

    sampler = WeightedRandomSampler(params1.h_weights, len(params1.h_weights))
    train_dataloader = DataLoader(train_dataset, batch_size=params1.train_batch_size, drop_last=True, sampler=sampler)
    val_dataloader = DataLoader(val_dataset, batch_size=params1.val_batch_size, drop_last=True)
    num_train_batches = len(train_dataloader)
    if num_train_batches == 0 or len(val_dataloader) == 0:
        raise ValueError("train/validation set is smaller than its batch size; drop_last=True leaves no batches")

    vae_optimizer = torch.optim.AdamW(hate_model.parameters(), lr=params1.content_lr, weight_decay=params1.weight_decay)
    hate_loss = nn.CrossEntropyLoss(params1.hate_class_weights)

    # NOTE(review): sized as 10 epochs' worth of batches, while the optimizer and scheduler only step
    # once every `accumulation_steps` batches, so the decay spans many more epochs. Kept as-is.
    s_total_steps = float(10 * len(train_dataset)) / params1.train_batch_size
    v_scheduler = get_linear_schedule_with_warmup(vae_optimizer, int(s_total_steps * params1.warmup_ratio),
                                                  math.ceil(s_total_steps))

    def combined_loss(h_loss, elbo_loss_l, con_loss, confidence_loss):
        """Weighted sum of the four loss terms, shared by training and validation."""
        return (params1.hate_coeff * h_loss + params1.alpha * elbo_loss_l
                + params1.target_loss_coeff * con_loss + params1.target_loss_coeff * confidence_loss)

    best_validation_accuracy = 1e-5  # best rounded validation macro-F1 seen so far
    best_report = None
    epochs_without_improvement = 0
    print("Training started!")

    for epoch in range(params1.num_epochs):
        total_vae_loss = 0
        total_adversary_loss = 0  # target loss is disabled; kept so the logged metric schema is stable
        total_cls_loss = 0
        total_acc_train = 0
        total_acc_target = 0      # target accuracy is disabled; kept for the same reason
        cnt = 0
        hate_model.train()
        vae_optimizer.zero_grad()
        for train_input, train_mask, batch_target, train_label, train_dec_input, train_dec_mask in train_dataloader:
            cnt += 1
            train_input = train_input.to(device)
            train_mask = train_mask.to(device)
            batch_target = batch_target.to(device)
            train_label = train_label.to(device)
            train_dec_input = train_dec_input.to(device)
            train_dec_mask = train_dec_mask.to(device)
            elbo_loss_l, hate_logits, con_loss, confidence_loss = hate_model(
                train_input, train_mask, train_dec_input, train_dec_mask, train_input, batch_target, train_label)
            h_loss = hate_loss(hate_logits, train_label)
            loss = combined_loss(h_loss, elbo_loss_l, con_loss, confidence_loss)

            # Gradient accumulation: back-propagate every batch (scaled so the accumulated gradient is a
            # mean) and step every `accumulation_steps` batches, plus once for any remainder at epoch end.
            (loss / accumulation_steps).backward()
            if cnt % accumulation_steps == 0 or cnt == num_train_batches:
                torch.nn.utils.clip_grad_norm_(hate_model.parameters(), 1.0)
                vae_optimizer.step()
                v_scheduler.step()
                vae_optimizer.zero_grad()

            batch_preds = torch.argmax(hate_logits, dim=1).cpu().numpy()
            if cnt % 100 == 0:
                print("Temperature", hate_model.temperature.item(), "batch", cnt, "epoch", epoch,
                      "1s", int((batch_preds == 1).sum()), "loss", loss.item(), "h_loss", h_loss.item(),
                      "elbo_loss", params1.alpha * elbo_loss_l.item(),
                      "vad_loss", params1.target_loss_coeff * con_loss.item())
            acc = round(accuracy_score(torch.argmax(train_label, dim=1).cpu().numpy(), batch_preds) * 100, 3)
            total_vae_loss += elbo_loss_l.item()
            total_cls_loss += h_loss.item()
            total_acc_train += acc
            if cnt % 100 == 0:
                print("Train Accuracy", acc)

        y_true = []
        predictions = []
        hate_model.eval()
        total_loss_val = 0
        total_acc_val = 0
        total_acc_target_val = 0  # target accuracy is disabled; kept for the metric schema
        e_cnt = 0
        print("VALIDATION")
        with torch.no_grad():
            for val_input, val_mask, val_batch_target, val_label, val_dec_input, val_dec_mask in val_dataloader:
                e_cnt += 1
                val_input = val_input.to(device)
                val_mask = val_mask.to(device)
                val_batch_target = val_batch_target.to(device)
                val_label = val_label.to(device)
                val_dec_input = val_dec_input.to(device)
                val_dec_mask = val_dec_mask.to(device)
                elbo_loss_l, hate_logits, con_loss, confidence_loss = hate_model(
                    val_input, val_mask, val_dec_input, val_dec_mask, val_input, val_batch_target, val_label)
                total_loss_val += combined_loss(hate_loss(hate_logits, val_label), elbo_loss_l,
                                                con_loss, confidence_loss).item()
                batch_true = torch.argmax(val_label, dim=1).cpu().numpy().tolist()
                batch_pred = torch.argmax(hate_logits, dim=1).cpu().numpy().tolist()
                total_acc_val += round(accuracy_score(batch_true, batch_pred) * 100, 3)
                y_true += batch_true
                predictions += batch_pred
        print(classification_report(y_true, predictions))

        metrics = {
            "Epoch": epoch,
            "Train ELBO Loss": total_vae_loss / cnt,
            "Train Target Loss": total_adversary_loss / cnt,
            "Train Hate Loss": total_cls_loss / cnt,
            "Train Hate Accuracy": total_acc_train / cnt,
            "Train Target Accuracy": total_acc_target / cnt,
            "Validation Loss": total_loss_val / e_cnt,
            "Validation Hate Accuracy": total_acc_val / e_cnt,
            "Validation Target Accuracy": total_acc_target_val / e_cnt}
        print(metrics)

        macro_f1 = classification_report(y_true, predictions, output_dict=True)['macro avg']['f1-score']
        # A single log call per epoch so the wandb step counter advances once per epoch.
        wandb.log({**metrics, "F1-Score": macro_f1})

        if best_validation_accuracy <= round(macro_f1, 3):
            best_validation_accuracy = round(macro_f1, 3)
            best_report = classification_report(y_true, predictions)
            epochs_without_improvement = 0
            print("E from if = ", epochs_without_improvement)
            fname = "best-model_" + params1.dataset_name + "_" + str(epoch + 1) + params1.fname
            os.makedirs(save_dir, exist_ok=True)
            torch.save(hate_model.state_dict(), os.path.join(save_dir, fname))
            print("Saved at ", os.path.join(save_dir, fname))
        else:
            print("E = ", epochs_without_improvement)
            epochs_without_improvement += 1
        if epochs_without_improvement == patience:
            print(best_report)
            break

In [None]:
import wandb
import pandas as pd
from sklearn.utils.class_weight import compute_class_weight
from types import SimpleNamespace


# Re-seed torch so this run cell is reproducible on its own.
torch.manual_seed(42)

# files = ['D:/Hate Speech/Preprocessed Datasets/GAB/gab_HX_Numeric_train.csv',
#            'D:/Hate Speech/Preprocessed Datasets/Reddit/reddit_TRY_Numeric_train.csv',
#            'D:/Hate Speech/Preprocessed Datasets/Twitter/twitter_HX_Numeric_train.csv',
#         #    './Preprocessed Datasets/Twitter/twitter_TRY_Numeric_train.csv',
#         #    './Preprocessed Datasets/Twitter/twitter_Numeric_train.csv',
#         #     './Preprocessed Datasets/Youtube/youtube_Numeric_train.csv',
#         #     './Preprocessed Datasets/Youtube/youtube_TRY_Numeric_train.csv',
#             'D:/Hate Speech/Preprocessed Datasets/YouTube/youtube_IC_Numeric_balanced.csv']
# files = ["D:/Hate Speech/Preprocessed Datasets/RedditNew/RedditNew_Train.csv",
#               "D:/Hate Speech/Preprocessed Datasets/TwitterNew/TwitterNew_Train.csv"]

# Training CSV(s). The loop below pairs files[f] with dataset_names[f] and evaluates on
# test_files[f + 1], or on test_files[0] for the last file.
# NOTE(review): `files` currently holds one YouTube file while dataset_names[0] is "Reddit_New", so the
# run/checkpoint name and the chosen test file (test_files[0], Reddit) may not match the training data -- confirm.
files = ['D:/Hate Speech/Preprocessed Datasets/YouTube/youtube_IC_Numeric_balanced.csv']

# Held-out test CSVs, indexed by position relative to `files` (see the note above).
test_files = ["D:/Hate Speech/Preprocessed Datasets/RedditNew/RedditNew_Test.csv",
              "D:/Hate Speech/Preprocessed Datasets/TwitterNew/TwitterNew_Test.csv"]#'D:/Hate Speech/Preprocessed Datasets/GAB/gab_HX_Numeric_test.csv',
        #  'D:/Hate Speech/Preprocessed Datasets/Reddit/reddit_TRY_Numeric_test.csv',
        #   'D:/Hate Speech/Preprocessed Datasets/Twitter/twitter_HX_Numeric_test.csv',
        #    './Preprocessed Datasets/Twitter/twitter_TRY_Numeric_test.csv',
        #    './Preprocessed Datasets/Twitter/twitter_Numeric_test.csv',
        #     './Preprocessed Datasets/Youtube/youtube_Numeric_test.csv',
        #     './Preprocessed Datasets/Youtube/youtube_TRY_Numeric_test.csv',
        #    'D:/Hate Speech/Preprocessed Datasets/Youtube/youtube_IC_Numeric_test.csv']

# Fine-tuned encoder checkpoints.
# NOTE(review): not referenced in the visible part of the loop -- presumably passed to HATE_WATCH as
# enc_chk further down; confirm.
enc_files = [#'D:/Hate Speech/Finetuned/Roberta/gab_HX/pytorch_model.bin',
            'D:/Hate Speech/Encoder_FT/Roberta/RedditNew/pytorch_model.bin',
            'D:/Hate Speech/Encoder_FT/Roberta/TwitterNew/pytorch_model.bin']
            # './FineTuned/Roberta/twitter_TRY/pytorch_model.bin',
            # './FineTuned/Roberta/twitter/pytorch_model.bin',
            # './FineTuned/Roberta/youtube/pytorch_model.bin',
            # './FineTuned/Roberta/youtube_TRY/pytorch_model.bin',
            #'D:/Hate Speech/Finetuned/Roberta/youtube_IC/pytorch_model.bin']

# enc_files = ['./FineTuned/Roberta/gab/pytorch_model.bin',
#              './FineTuned/RobertaHate/reddit_TRY/pytorch_model.bin',
#              './FineTuned/RobertaHate/twitter/pytorch_model.bin',
#              './FineTuned/RobertaHate/youtube/pytorch_model.bin']

# Fine-tuned decoder checkpoints.
# NOTE(review): not referenced in the visible code; HATE_WATCH's dec_chk loading is commented out -- confirm use.
dec_files = [#'D:/Hate Speech/Finetuned/BART/pytorch_model.bin',
             'D:/Hate Speech/Decoder_FT/BART/RedditNew/pytorch_model.bin',
             'D:/Hate Speech/Decoder_FT/BART/TwitterNew/pytorch_model.bin']
             #'D:/Hate Speech/Finetuned/BART_youtube/pytorch_model.bin']


# Run/checkpoint name for each entry of `files` (same index).
dataset_names = ["Reddit_New","Twitter_New"]#["gab_HX","reddit_TRY","twitter_HX","youtube_IC"]#"twitter_TRY","twitter","youtube","youtube_TRY",
# Model width; matches the "hidden_size" entry of the run config below.
hidden_size = 768
# Dropout of the classifier heads; also part of the wandb run name.
classifier_dropout = 0.2
# NOTE(review): in the visible code this only appears in the wandb run name. train() uses
# params1.content_lr (2e-5 in the run config) for AdamW -- confirm the run name is not misleading.
learning_rate = 5e-3
print(files)
print(len(files))
print(device)
# Dataset names already processed in this session; the loop skips repeated names.
filenames = set()
# NOTE(review): not referenced in the visible code.
latent_variables = ['hate','target']#['0','1','2','3','4','5','6','7']

for f in range(0,len(files)):
        torch.cuda.empty_cache()
        train_frame = pd.read_csv(files[f])
        tar = np.zeros(8)
        #a1 = np.unique(train_frame['target'])
        #class_weights = compute_class_weight('balanced', classes=a1, y=train_frame['target'])
        # for i in range(len(a1)):
        #     try:
        #         tar[a1[i]] = class_weights[a1[i]]
        #     except:
        #         tar[a1[i]] = 0
        # print(tar)
        # class_weights2 = torch.FloatTensor(tar).to(device)
        #print(class_weights2)
        if(dataset_names[f] not in filenames):
            filenames.add(dataset_names[f])
            # train_frame = pd.read_csv(files[f])
            if(f+1<len(files)):
                print("TEST FILE: ",test_files[f+1])
                test_frame = pd.read_csv(test_files[f+1])
            else:
                 test_frame = pd.read_csv(test_files[0])
            class_weights1 = compute_class_weight('balanced', classes=np.unique(train_frame['label']), y=train_frame['label'])
            class_weights1 = torch.FloatTensor(class_weights1)
            x = train_frame['label'].value_counts().values
            class_weights3 = torch.FloatTensor([x[0]/sum(x),x[1]/sum(x)])
            # class_weights3 = torch.FloatTensor([0.6,0.4])
            h_weights = class_weights3[train_frame['label']]
            print(h_weights)
            class_weights1 = class_weights1.to(device)
            wandb.init(
                # Set the project where this run will be logged
                project="Disentangling Hate Speech and Target", 
                # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
                name=f"experiment_ROB_BART_{dataset_names[f]}_{classifier_dropout}_{learning_rate}",
                # Track hyperparameters and run metadata
                config = { 
                        "max_sequence_length": 512, 
                        "train_batch_size" : 8, 
                        "val_batch_size" : 8,
                        "hidden_dim" : 512, 
                        "hate_dim" : 384,
                        "num_epochs" : 100, 
                        "device" : device,
                        "dataset_name" : dataset_names[f],
                        "h_weights" : h_weights,
                        "hate_class_weights" : class_weights1,
                        #"target_class_weights" : class_weights2,
                        "hidden_size" : 768,
                        "num_labels": 2,
                        "num_targets": 8,
                        "classifier_dropout" : classifier_dropout,
                        "content_lr": 2e-5,
                        "decoder_type" : "BART",
                        "kl_weight" : 0.05,
                        "mi_loss_weight" : 0.001,
                        "mi_loss" : False,
                        "alpha" : 1,
                        "beta_c" : 0.05,
                        "beta_d" : 0.05,
                        "warmup_ratio" : 0.2,
                        "weight_decay" : 0.001,
                        "target_loss_coeff" : 1,
                        "hate_coeff" : 1,
                        "fname" : "_VAE_with_rob_base_WS_OURS.pt",
            })
            params1 = { 
                        "max_sequence_length": 512, 
                        "train_batch_size" : 8, 
                        "val_batch_size" : 8,
                        "hidden_dim" : 512, 
                        "hate_dim" : 384,
                        "num_epochs" : 100, 
                        "device" : device,
                        "dataset_name" : dataset_names[f],
                        "h_weights" : h_weights,
                        "hate_class_weights" : class_weights1,
                        #"target_class_weights" : class_weights2,
                        "hidden_size" : 768,
                        "num_labels": 2,
                        "num_targets": 8,
                        "classifier_dropout" : classifier_dropout,
                        "content_lr": 2e-5,
                        "decoder_type" : "BART",
                        "kl_weight" : 0.05,
                        "mi_loss_weight" : 0.001,
                        "mi_loss" : False,
                        "alpha" : 1,
                        "beta_c" : 0.05,
                        "beta_d" : 0.05,
                        "warmup_ratio" : 0.2,
                        "weight_decay" : 0.001,
                        "target_loss_coeff" : 1,
                        "hate_coeff" : 1,
                        "fname" : "_VAE_with_rob_base_WS_OURS.pt",
            }
            params1 = SimpleNamespace(**params1)
            # enc_model_name = "facebook/roberta-hate-speech-dynabench-r4-target"
            enc_model_name = "roberta-base"
            dec_model_name = "facebook/bart-base"
            # enc_chk = "./FineTuned/Roberta/pytorch_model.bin"
            print(enc_files[f])
            print(params1.fname)
            enc_chk = enc_files[f]#"./FineTuned/RobertaHate/pytorch_model.bin"
            dec_chk = dec_files[f]#"./FineTuned/BART/pytorch_model.bin"
            hate_model = HATE_WATCH(1.0,0.1,0.95,enc_chk,enc_model_name,dec_chk,dec_model_name, params1.num_targets,params1.num_labels,params1.beta_c, params1.beta_d,device,params1.train_batch_size,params1.decoder_type).to(device)
            encoder_tokenizer = RobertaTokenizer.from_pretrained(enc_model_name)
            decoder_tokenizer = BartTokenizer.from_pretrained(dec_model_name)
            print(params1.dataset_name)
            print(files[f])
            print(train_frame.shape)
            wandb.watch(hate_model)
            train(hate_model=hate_model,
            train_data=train_frame['text'].values.tolist(), 
            train_labels=train_frame['label'].values.tolist(), 
            train_target=[""]*len(train_frame['label'].values.tolist()), 
            val_data=test_frame['text'].values.tolist(), 
            val_labels=test_frame['label'].values.tolist(), 
            val_target = [""]*len(train_frame['label'].values.tolist()),
            encoder_tokenizer=encoder_tokenizer,
            decoder_tokenizer=decoder_tokenizer,
            params1=params1)

In [None]:
def evaluate(model, test_data, test_labels, target_labels, encoder_tokenizer, decoder_tokenizer, params1):
    """Score `model` on a labelled test set and print hate-classification metrics.

    :param model: the module to evaluate; called as
        model(input_ids, mask, dec_input_ids, dec_mask, decoder_labels, targets, labels)
        and expected to return a tuple whose element [1] holds the hate-class scores.
    :param test_data: input texts.
    :param test_labels: hate labels, one per text.
    :param target_labels: target labels, one per text (callers in this notebook pass
        empty-string placeholders).
    :param encoder_tokenizer: tokenizer for the encoder inputs.
    :param decoder_tokenizer: tokenizer for the decoder inputs.
    :param params1: namespace providing max_sequence_length and val_batch_size.
    :return: accuracy in percent (rounded to 3 decimals) over the scored examples, or
        None when the test set is too small to form a single batch.
    """
    test = EncodedDataset(input_sents=test_data,
                          input_labels=test_labels,
                          target_labels=target_labels,
                          encoder_tokenizer=encoder_tokenizer,
                          decoder_tokenizer=decoder_tokenizer,
                          max_sequence_length=params1.max_sequence_length)

    # NOTE(review): drop_last=True skips the final partial batch, so up to
    # val_batch_size - 1 examples are never scored. Kept because HATE_WATCH's
    # forward has not been verified for batches of size 1.
    val_dataloader = DataLoader(test, batch_size=params1.val_batch_size, shuffle=False, drop_last=True)

    y_true = []
    predictions = []
    # Evaluate the model that was passed in, not a notebook global.
    model.eval()
    with torch.no_grad():
        for val_input, val_mask, val_target, val_label, val_dec_input, val_dec_mask in val_dataloader:
            val_input = val_input.to(device)
            val_mask = val_mask.to(device)
            val_target = val_target.to(device)
            val_label = val_label.to(device)
            val_dec_input = val_dec_input.to(device)
            val_dec_mask = val_dec_mask.to(device)
            # The encoder input ids double as the reconstruction labels.
            outputs = model(val_input, val_mask, val_dec_input, val_dec_mask, val_input, val_target, val_label)
            # Index instead of unpacking: HATE_WATCH.forward in this notebook returns
            # 3 values, so a fixed 4-way unpack fails on it.
            hate_logits = outputs[1]
            # Labels are argmax'ed over dim 1, i.e. presumably one-hot encoded.
            y_true += torch.argmax(val_label, dim=1).cpu().tolist()
            predictions += torch.argmax(hate_logits, dim=1).cpu().tolist()

    if not y_true:
        print("No complete batch of size", params1.val_batch_size, "in the test set; nothing to evaluate.")
        return None

    print(classification_report(y_true, predictions))
    # Accuracy over all scored examples; with equal-sized batches this equals the
    # mean of per-batch accuracies, without per-batch rounding error.
    accuracy = round(accuracy_score(y_true, predictions) * 100, 3)
    print("Accuracy: ", accuracy)
    return accuracy

In [None]:
# model_files = [#'./BuildingFramework/Models/Disentanglement/best-model_gab_HX_1_VAE_with_rob_base_ft.pt',
#                #'./BuildingFramework/Models/Disentanglement/best-model_gab_HX_2_VAE_with_rob_base_ft.pt',
#                #'./BuildingFramework/Models/Disentanglement/best-model_gab_HX_3_VAE_with_rob_base_ft.pt',
#                #'./BuildingFramework/Models/Disentanglement/best-model_gab_HX_4_VAE_with_rob_base_ft.pt',
#                './BuildingFramework/Models/Disentanglement/best-model_gab_HX_13_VAE_with_rob_base_ft_balancedData.pt',
#                #'./BuildingFramework/Models/Disentanglement/best-model_gab_HX_19_VAE_with_rob_base_NoFT_balancedData.pt',
#                './BuildingFramework/Models/Disentanglement/best-model_gab_HX_6_VAE_with_rob_base_NoHL_balancedData.pt',
#                './BuildingFramework/Models/Disentanglement/best-model_gab_HX_6_VAE_with_rob_base_NoBL_balancedData.pt',
#                #'./BuildingFramework/Models/Disentanglement/best-model_reddit_TRY_1_VAE_with_rob_base_ft.pt',
#                './BuildingFramework/Models/Disentanglement/best-model_reddit_TRY_2_VAE_with_rob_base_ft_balancedData.pt',
#                #'./BuildingFramework/Models/Disentanglement/best-model_reddit_TRY_3_VAE_with_rob_base_NoFT_balancedData.pt',
#                './BuildingFramework/Models/Disentanglement/best-model_reddit_TRY_1_VAE_with_rob_base_NoHL_balancedData.pt',
#                './BuildingFramework/Models/Disentanglement/best-model_reddit_TRY_3_VAE_with_rob_base_NoTL_balancedData.pt',
#                './BuildingFramework/Models/Disentanglement/best-model_reddit_TRY_1_VAE_with_rob_base_NoBL_balancedData.pt',
#                #'./BuildingFramework/Models/Disentanglement/best-model_Twitter_HX_1_VAE_with_rob_base_ft.pt',
#                #'./BuildingFramework/Models/Disentanglement/best-model_Twitter_HX_2_VAE_with_rob_base_ft.pt',
#                #'./BuildingFramework/Models/Disentanglement/best-model_Twitter_HX_3_VAE_with_rob_base_ft.pt',
#                #'./BuildingFramework/Models/Disentanglement/best-model_Twitter_HX_4_VAE_with_rob_base_ft.pt',
#                #'./BuildingFramework/Models/Disentanglement/best-model_Twitter_HX_5_VAE_with_rob_base_ft.pt',
#                #'./BuildingFramework/Models/Disentanglement/best-model_Twitter_HX_6_VAE_with_rob_base_ft.pt',
#                 './BuildingFramework/Models/Disentanglement/best-model_twitter_HX_6_VAE_with_rob_base_ft_balancedData.pt',
#               # './BuildingFramework/Models/Disentanglement/best-model_twitter_HX_7_VAE_with_rob_base_NoFT_balancedData.pt',
#                 './BuildingFramework/Models/Disentanglement/best-model_twitter_HX_3_VAE_with_rob_base_NoHL_balancedData.pt',
#                 './BuildingFramework/Models/Disentanglement/best-model_twitter_HX_11_VAE_with_rob_base_NoTL_balancedData.pt',
#                 './BuildingFramework/Models/Disentanglement/best-model_twitter_HX_3_VAE_with_rob_base_NoBL_balancedData.pt',
#                #'./BuildingFramework/Models/Disentanglement/best-model_Twitter_TRY_1_VAE_with_rob_base_ft.pt',
#                #'./BuildingFramework/Models/Disentanglement/best-model_Twitter_TRY_2_VAE_with_rob_base_ft.pt',
#                #'./BuildingFramework/Models/Disentanglement/best-model_twitter_1_VAE_with_rob_base_ft.pt',
#                #'./BuildingFramework/Models/Disentanglement/best-model_twitter_2_VAE_with_rob_base_ft.pt',
#                #'./BuildingFramework/Models/Disentanglement/best-model_youtube_1_VAE_with_rob_base_ft.pt',
#                #'./BuildingFramework/Models/Disentanglement/best-model_youtube_2_VAE_with_rob_base_ft.pt',
#                #'./BuildingFramework/Models/Disentanglement/best-model_youtube_TRY_1_VAE_with_rob_base_ft.pt',
#                #'./BuildingFramework/Models/Disentanglement/best-model_youtube_IC_1_VAE_with_rob_base_ft.pt',
#                #'./BuildingFramework/Models/Disentanglement/best-model_youtube_IC_4_VAE_with_rob_base_ft.pt',
#                #'./BuildingFramework/Models/Disentanglement/best-model_youtube_IC_5_VAE_with_rob_base_ft.pt',
#                #'./BuildingFramework/Models/Disentanglement/best-model_youtube_IC_6_VAE_with_rob_base_ft.pt',
#                #'./BuildingFramework/Models/Disentanglement/best-model_youtube_IC_8_VAE_with_rob_base_ft.pt',
#                #'./BuildingFramework/Models/Disentanglement/best-model_youtube_IC_11_VAE_with_rob_base_ft.pt',
#                 './BuildingFramework/Models/Disentanglement/best-model_youtube_IC_10_VAE_with_rob_base_ft_balancedData.pt',
#               # './BuildingFramework/Models/Disentanglement/best-model_youtube_IC_1_VAE_with_rob_base_NoFT_balancedData.pt',
#                 './BuildingFramework/Models/Disentanglement/best-model_youtube_IC_2_VAE_with_rob_base_NoHL_balancedData.pt',
#                 './BuildingFramework/Models/Disentanglement/best-model_youtube_IC_7_VAE_with_rob_base_NoTL_balancedData.pt',
#                 './BuildingFramework/Models/Disentanglement/best-model_youtube_IC_2_VAE_with_rob_base_NoBL_balancedData.pt']

# model_files = ["D:/Hate Speech/Models/Disentanglement/WeakSupervision/best-model_gab_HX_24_VAE_with_rob_base_WS_OURS.pt",
#                "D:/Hate Speech/Models/Disentanglement/WeakSupervision/best-model_reddit_TRY_6_VAE_with_rob_base_WS_OURS.pt",
#                "D:/Hate Speech/Models/Disentanglement/WeakSupervision/best-model_twitter_HX_6_VAE_with_rob_base_WS_OURS.pt",
#                "D:/Hate Speech/Models/Disentanglement/WeakSupervision/best-model_youtube_IC_10_VAE_with_rob_base_WS_OURS.pt"]

model_files = ["D:/Hate Speech/Models/Disentanglement/WeakSupervision/best-model_Twitter_New_1_VAE_with_rob_base_WS_OURS.pt",
               "D:/Hate Speech/Models/Disentanglement/WeakSupervision/best-model_Reddit_New_4_VAE_with_rob_base_WS_OURS.pt"]

# NOTE(review): this rebinds the global test_files that the training cell also
# reads; re-running training after this cell would pick up this list.
test_files = ['D:/Hate Speech/Preprocessed Datasets/GAB/gab_HX_Numeric_test.csv',
              'D:/Hate Speech/Preprocessed Datasets/RedditNew/RedditNew_Test.csv',
              'D:/Hate Speech/Preprocessed Datasets/TwitterNew/TwitterNew_Test.csv',
              # './Preprocessed Datasets/Twitter/twitter_TRY_Numeric_test.csv',
              # './Preprocessed Datasets/Twitter/twitter_Numeric_test.csv',
              'D:/Hate Speech/Preprocessed Datasets/Youtube/youtube_Numeric_test.csv']
              # 'D:/Hate Speech/Preprocessed Datasets/Youtube/youtube_TRY_Numeric_test.csv',
              # 'D:/Hate Speech/Preprocessed Datasets/Youtube/youtube_IC_Numeric_test.csv'
# test_files = ['./reddit_TRY_Numeric_test_subset.csv']

# Evaluate every checkpoint on every test set. Uses hate_model, encoder_tokenizer,
# decoder_tokenizer, params1, device and evaluate from earlier cells.
for model_path in model_files:
    print(model_path)
    print("*" * 100)
    # Load once per checkpoint rather than once per test file; map_location lets
    # GPU-saved checkpoints load when only a CPU is available.
    hate_model.load_state_dict(torch.load(model_path, map_location=device))
    for test_path in test_files:
        print(test_path)
        print("*" * 100)
        test_frame = pd.read_csv(test_path)
        evaluate(model=hate_model,
                 test_data=test_frame['text'].values.tolist(),
                 test_labels=test_frame['label'].values.tolist(),
                 target_labels=[""] * len(test_frame),
                 encoder_tokenizer=encoder_tokenizer,
                 decoder_tokenizer=decoder_tokenizer,
                 params1=params1)

In [None]:
# Integer id for each of the eight target-group categories (ids follow list order).
target_label_dict = {
    category: index
    for index, category in enumerate([
        "Race",
        "Religion",
        "Sexuality and Sexual Preferences",
        "Gender",
        "Immigration Status",
        "Nationality",
        "Ableness/Disability",
        "Class",
    ])
}

In [None]:
class _Inference(nn.Sequential):
    """Small head mapping an input vector to a latent-parameter vector.

    disc_variable=True:  Linear(C -> C//2) -> ReLU -> Linear(C//2 -> latent_dim)
                         -> LogSoftmax over dim 1 (log-probabilities of a categorical).
    disc_variable=False: a single Linear(C -> latent_dim), as used for the Gaussian
                         mean / log-sigma heads.
    Sub-module names ('fc', 'relu', 'fc2', 'log_softmax') are fixed because they
    appear in saved state-dict keys.
    """
    def __init__(self, num_input_channels, latent_dim, disc_variable=True):
        super().__init__()
        if not disc_variable:
            self.add_module('fc', nn.Linear(num_input_channels, latent_dim))
            return

        bottleneck = num_input_channels // 2
        named_layers = (
            ('fc', nn.Linear(num_input_channels, bottleneck)),
            ('relu', nn.ReLU()),
            ('fc2', nn.Linear(bottleneck, latent_dim)),
            ('log_softmax', nn.LogSoftmax(dim=1)),
        )
        for layer_name, layer in named_layers:
            self.add_module(layer_name, layer)

class Sample(nn.Module):
    """Reparameterised sampling for the model's latent variables."""

    def __init__(self, temperature):
        super(Sample, self).__init__()
        # Gumbel-softmax temperature; HATE_WATCH passes its learnable nn.Parameter here.
        self._temperature = temperature

    def forward(self, norm_mean, norm_log_sigma, norm_mean_t, norm_log_sigma_t, disc_label=None, mixup=False, disc_label_mixup=None,
                mixup_lam=None):
        """
        Draw one sample from each of two diagonal Gaussians and concatenate them.

        :param norm_mean: mean of the first Gaussian latent, [N, D1]
        :param norm_log_sigma: log standard deviation of the first latent, [N, D1]
        :param norm_mean_t: mean of the second (target) Gaussian latent, [N, D2]
        :param norm_log_sigma_t: log standard deviation of the second latent, [N, D2]
        :param disc_label: unused (kept for interface compatibility)
        :param mixup: unused (kept for interface compatibility)
        :param disc_label_mixup: unused (kept for interface compatibility)
        :param mixup_lam: unused (kept for interface compatibility)
        :return: concatenated samples reshaped to [N, D1 + D2, 1, 1]
        """
        batch_size = norm_mean.size(0)
        latent_sample = torch.cat([self._sample_norm(norm_mean, norm_log_sigma),
                                   self._sample_norm(norm_mean_t, norm_log_sigma_t)], dim=1)
        return latent_sample.view(batch_size, latent_sample.size(1), 1, 1)

    def _sample_gumbel_softmax(self, log_alpha):
        """
        Samples from a gumbel-softmax distribution using the reparameterization
        trick.

        Parameters
        ----------
        log_alpha : torch.Tensor
            Parameters of the gumbel-softmax distribution. Shape (N, D)

        Returns
        -------
        torch.Tensor of shape (N, D) whose rows sum to 1.
        """
        EPS = 1e-12
        # rand_like draws on log_alpha's own device/dtype (an unconditional .cuda()
        # crashed on CPU-only machines).
        unif = torch.rand_like(log_alpha)
        gumbel = -torch.log(-torch.log(unif + EPS) + EPS)
        # Reparameterize to create gumbel softmax sample
        logit = (log_alpha + gumbel) / self._temperature
        return torch.softmax(logit, dim=1)

    @staticmethod
    def _sample_norm(mu, log_sigma):
        """
        :param mu: the mu for sampling with N*D
        :param log_sigma: the log_sigma for sampling with N*D
        Return the latent normal sample z ~ N(mu, sigma^2), with sigma = exp(log_sigma)
        """
        # randn_like matches mu's device and dtype; .cuda() always targeted the
        # current default GPU, which breaks when mu lives on another device.
        std_z = torch.randn_like(mu)
        return mu + torch.exp(log_sigma) * std_z

class HATE_WATCH(nn.Module):
    """VAE-style hate-speech model with separate "hate" and "target" latents.

    A RoBERTa encoder (parameters frozen) yields the first-token representation of the
    input, which is batch-normalised and mapped to two diagonal Gaussian latents:
      * continuous_inference  -> hate latent, fed to hate_classifier;
      * disc_latent_inference -> target latent, fed to target_classifier and a
        contrastive loss.
    The concatenated latent samples are compressed back to hidden_size and given to a
    BART decoder (as a length-1 cross-attention memory) or an LSTM (as initial hidden
    state) to reconstruct the token sequence.
    """
    def __init__(self, temperature, gamma, enc_chk, check_point, dec_chk, bart_check_point, num_targets, num_class,
                 beta_c, beta_d, device, batch_size, decoder_type, x_sigma=1, classifier_dropout=None):
        """
        :param temperature: initial value of a learnable temperature, handed to Sample.
        :param gamma: margin used by get_contrastive_loss.
        :param enc_chk: path to a RobertaForSequenceClassification state dict whose
            .roberta body becomes the encoder, or "" to load `check_point` as a RobertaModel.
        :param check_point: name/path of the RoBERTa model, tokenizer and config.
        :param dec_chk: unused (loading a fine-tuned decoder is currently disabled).
        :param bart_check_point: name/path of the BART model whose decoder is used.
        :param num_targets: number of target classes.
        :param num_class: number of hate classes.
        :param beta_c: weight of the continuous KL term in the ELBO.
        :param beta_d: stored as kl_beta_d; unused by the current loss.
        :param device: device for the pretrained sub-modules and the learnable scalars.
        :param batch_size: stored as self.batch_size.
        :param decoder_type: 'BART' for a BartDecoder; anything else gives a 1-layer LSTM.
        :param x_sigma: stored; unused by the current loss.
        :param classifier_dropout: dropout rate of the classifier heads. When None, the
            notebook-global ``params1.classifier_dropout`` is used (previous behaviour).
        """
        super(HATE_WATCH, self).__init__()
        self.device = device
        self.batch_size = batch_size
        self.decoder_type = decoder_type
        self.temperature = nn.Parameter(torch.tensor(temperature, requires_grad=True, device=device))
        if classifier_dropout is None:
            # Backward compatibility with call sites that rely on the global params1.
            classifier_dropout = params1.classifier_dropout

        # Encoder: optionally start from a fine-tuned sequence classifier's RoBERTa body.
        if enc_chk != "":
            seq_classifier = RobertaForSequenceClassification.from_pretrained(check_point).to(device)
            seq_classifier.load_state_dict(torch.load(enc_chk, map_location=device))
            self.encoder = seq_classifier.roberta
        else:
            self.encoder = RobertaModel.from_pretrained(check_point).to(device)
        # The encoder is frozen; it is never updated by the optimiser.
        for param in self.encoder.base_model.parameters():
            param.requires_grad = False
        self.tokenizer = RobertaTokenizer.from_pretrained(check_point)
        self.config = AutoConfig.from_pretrained(check_point)

        # Learnable KL "capacity" targets: the loss penalises |KL - cmi|.
        self.cmi = torch.nn.Parameter(torch.rand(1, requires_grad=True, device=device))
        self.dmi = torch.nn.Parameter(torch.rand(1, requires_grad=True, device=device))

        hidden_size = self.config.hidden_size
        if decoder_type == 'BART':
            self.decoder = BartDecoder.from_pretrained(bart_check_point).to(device)
        else:
            self.decoder = nn.LSTM(hidden_size, hidden_size, 1, batch_first=True)

        self.decoder_start_token_id = 2
        # Token logits over the encoder config's vocabulary.
        self.lm_head = nn.Linear(hidden_size, self.config.vocab_size)
        self.lm_loss_fn = nn.CrossEntropyLoss(reduction='none')
        self.x_sigma = x_sigma
        self.relu = nn.ReLU()
        self.gamma = gamma
        # Maps the concatenated [hate; target] sample (2 * hidden) back to hidden_size.
        self.compressor = nn.Linear(2 * hidden_size, hidden_size)

        self.dropout_rate = classifier_dropout

        # Disentanglement modules: one (mean, log_sigma) head pair per latent.
        self.continuous_inference = nn.Sequential()
        self.disc_latent_inference = nn.Sequential()
        self.continuous_inference.add_module("mean", _Inference(num_input_channels=hidden_size,
                                                                latent_dim=hidden_size,
                                                                disc_variable=False))
        self.continuous_inference.add_module("log_sigma", _Inference(num_input_channels=hidden_size,
                                                                     latent_dim=hidden_size,
                                                                     disc_variable=False))
        self.disc_latent_inference.add_module("mean", _Inference(num_input_channels=hidden_size,
                                                                 latent_dim=hidden_size,
                                                                 disc_variable=False))
        self.disc_latent_inference.add_module("log_sigma", _Inference(num_input_channels=hidden_size,
                                                                      latent_dim=hidden_size,
                                                                      disc_variable=False))
        self.sample = Sample(temperature=self.temperature)

        self.kl_beta_c = beta_c
        self.kl_beta_d = beta_d

        # Reconstructor
        self.reconstructor = nn.Linear(hidden_size, hidden_size)
        self.batch_norm = nn.BatchNorm1d(num_features=hidden_size)
        # Hate classifier. Ends in Softmax, so it outputs probabilities.
        # NOTE(review): if train() applies CrossEntropyLoss to these, softmax is applied twice.
        self.hate_classifier = nn.Sequential(
            nn.Dropout(classifier_dropout),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Dropout(classifier_dropout),
            nn.Linear(hidden_size, num_class),
            nn.Softmax(dim=1)
        )

        # Target classifier (raw logits).
        self.target_classifier = nn.Sequential(
            nn.Dropout(classifier_dropout),
            nn.Tanh(),
            nn.Linear(hidden_size, num_targets)
        )
        self.mse_loss = nn.MSELoss(reduction='sum')

    def get_contrastive_loss(self, embeddings, labels):
        """Margin-based contrastive loss on the target latent.

        Pairs whose predicted class (argmax of `labels`) agrees are pulled together by
        their squared distance; other pairs are pushed apart until their distance
        exceeds ``self.gamma``. Distances are Euclidean divided by the embedding width.

        :param embeddings: [B, D] latent samples.
        :param labels: [B, C] class scores; only their argmax is used.
        :return: scalar loss.
        """
        # Softmax is monotonic, so the argmax of the raw scores picks the same classes.
        max_indices = torch.argmax(labels, dim=1)
        weights = (max_indices.unsqueeze(1) == max_indices.unsqueeze(0)).float()
        pairwise_diff = embeddings.unsqueeze(1) - embeddings.unsqueeze(0)
        squared_dist = torch.sum(pairwise_diff ** 2, dim=-1)
        # Clamp before sqrt: sqrt has an infinite derivative at 0, which turned the
        # zero-distance diagonal into NaN gradients.
        distance_matrix = torch.sqrt(torch.clamp_min(squared_dist, 1e-12)) / embeddings.shape[1]
        loss = torch.mean(weights * distance_matrix.pow(2)
                          + (1 - weights) * torch.clamp_min(self.gamma - distance_matrix, 0).pow(2), dim=1)
        return torch.mean(loss)

    def get_lm_loss(self, logits, labels, masks):
        """Token reconstruction loss.

        Per-token cross-entropy multiplied by `masks`, then averaged over all positions
        (masked positions count as zeros in the mean).
        NOTE(review): a true masked mean would divide by masks.sum(); left unchanged
        because it would rescale this term relative to the KL term.
        """
        loss = self.lm_loss_fn(logits.view(-1, self.config.vocab_size), labels.view(-1))
        masked_loss = loss * masks.view(-1)
        return torch.mean(masked_loss)

    def forward(self, inputs, mask, decoder_inputs, decoder_masks, decoder_labels, vad_labels, labels):
        """
        :param inputs: encoder input ids. Dim: [B, seq_len]
        :param mask: attention mask for `inputs`. Dim: [B, seq_len]
        :param decoder_inputs: decoder input ids.
        :param decoder_masks: decoder attention mask; also weights the reconstruction loss.
        :param decoder_labels: token ids the decoder output is scored against.
        :param vad_labels: unused.
        :param labels: unused.
        :return: (elbo loss scalar, hate-class probabilities [B, num_class],
                  contrastive loss on the target latent)
        """
        batch_size = inputs.size(0)
        x = self.encoder(inputs, attention_mask=mask)[0]
        x = x[:, 0, :]  # first-token representation, [B, hidden]
        x = self.batch_norm(x)

        # Parameters of the two Gaussian latents.
        norm_mean = self.continuous_inference.mean(x)
        norm_log_sigma = self.continuous_inference.log_sigma(x)
        norm_mean_t = self.disc_latent_inference.mean(x)
        norm_log_sigma_t = self.disc_latent_inference.log_sigma(x)
        latent_sample = self.sample(norm_mean, norm_log_sigma, norm_mean_t, norm_log_sigma_t)

        # Target branch: a separate sample of the target latent drives the target
        # classifier and the contrastive loss.
        t_sample = self.sample._sample_norm(norm_mean_t, norm_log_sigma_t)
        t_logits = self.target_classifier(t_sample)
        con_loss = self.get_contrastive_loss(t_sample, t_logits)

        # [B, 2*hidden, 1, 1] -> [B, 2*hidden]; torch.squeeze would also drop the
        # batch dimension when B == 1.
        latent_sample = latent_sample.flatten(start_dim=1)
        latent_sample = self.compressor(latent_sample)
        decoder_hidden = self.reconstructor(latent_sample)

        if self.decoder_type == 'BART':
            # The latent acts as a length-1 encoder memory for cross-attention.
            decoder_outputs = self.decoder(
                input_ids=decoder_inputs,
                attention_mask=decoder_masks,
                encoder_hidden_states=decoder_hidden.unsqueeze(1))
            lm_logits = self.lm_head(decoder_outputs.last_hidden_state)
        else:
            input_embeddings = self.encoder.embeddings(decoder_inputs)
            h = decoder_hidden.unsqueeze(0)
            decoder_outputs, (_, _) = self.decoder(input_embeddings, (h, torch.zeros_like(h)))
            lm_logits = self.lm_head(decoder_outputs)

        reconstruct_loss = self.get_lm_loss(lm_logits, decoder_labels, decoder_masks)

        # KL(q(z|x) || N(0, I)) of the hate latent, averaged over the actual batch.
        z_mean_sq = norm_mean * norm_mean
        z_log_sigma_sq = 2 * norm_log_sigma
        z_sigma_sq = torch.exp(z_log_sigma_sq)
        continuous_kl_loss = 0.5 * torch.sum(z_mean_sq + z_sigma_sq - z_log_sigma_sq - 1) / batch_size

        # Penalise the KL's distance from the learnable capacity cmi (shape [1]).
        prior_kl_loss_l = self.kl_beta_c * torch.abs(continuous_kl_loss - self.cmi)
        elbo_loss_l = reconstruct_loss + prior_kl_loss_l

        # Hate branch: classify a fresh sample of the hate latent.
        hate_sample = self.sample._sample_norm(norm_mean, norm_log_sigma)
        hate_logits = self.hate_classifier(hate_sample)

        return elbo_loss_l[0], hate_logits, con_loss

In [None]:
import pandas as pd

# Alternative (YouTube) sources:
# gpt_data = pd.read_csv("D:/Hate Speech/Preprocessed Datasets/YouTube/GPT4Analysis.csv")
# true_data_tr = pd.read_csv("D:/Hate Speech/Preprocessed Datasets/YouTube/youtube_Numeric.csv")
# true_data_te = pd.read_csv("D:/Hate Speech/Preprocessed Datasets/YouTube/youtube_IC_Numeric_balanced_test.csv")

# GAB: GPT-4 analysis file (presumably holding the `pred_target` column used below)
# plus the labelled train/test splits.
gpt_data = pd.read_csv("D:/Hate Speech/Preprocessed Datasets/GAB/GPT4Analysis.csv")
true_data_tr = pd.read_csv("D:/Hate Speech/Preprocessed Datasets/GAB/gab_HX_Numeric_train.csv")
true_data_te = pd.read_csv("D:/Hate Speech/Preprocessed Datasets/GAB/gab_HX_Numeric_test.csv")

# Inner-join the GPT-4 rows onto each split by exact text match.
data1 = gpt_data.merge(true_data_tr, on='text')
data2 = gpt_data.merge(true_data_te, on='text')

# Row counts before/after the joins.
(data1.shape, true_data_tr.shape, gpt_data.shape, true_data_te.shape, data2.shape)

In [None]:
target_label_dict = {"Race":0,"Religion":1,"Sexuality and Sexual Preferences":2,"Gender":3,"Immigration Status":4,"Nationality":5,"Ableness/Disability":6,"Class":7}


def _to_weak_supervision_frame(merged):
    """Build the weak-supervision table from a GPT-4 / gold-split merge.

    Maps the predicted target name (`pred_target`) to its integer id, drops rows whose
    prediction is not a known category, and keeps four columns:
    text, label, target (predicted id, int32) and tgt_og (the split's `target_y`).
    Duplicate rows are removed. `merged` itself is not modified.
    """
    ws = merged.assign(pred_target1=merged['pred_target'].map(target_label_dict))
    ws = ws.dropna(subset=['pred_target1'])
    # Assign the cast: after map() introduces NaNs the column is float64 and would
    # otherwise be written to CSV as 0.0, 1.0, ...
    ws = ws.astype({'pred_target1': 'int32'})
    ws = ws[['text', 'label', 'pred_target1', 'target_y']]
    ws = ws.rename(columns={'pred_target1': 'target', 'target_y': 'tgt_og'})
    return ws.drop_duplicates()


# NOTE: this cell overwrites data1/data2; re-run the merge cell before re-running it.
data1 = _to_weak_supervision_frame(data1)
data1.to_csv("D:/Hate Speech/Preprocessed Datasets/GAB/gab_hx_Numeric_train_WS.csv", index=False)

data2 = _to_weak_supervision_frame(data2)
data2.to_csv("D:/Hate Speech/Preprocessed Datasets/GAB/gab_hx_Numeric_test_WS.csv", index=False)
data1.shape, data2.shape