In [None]:
!mkdir models

In [None]:
import os
import subprocess
import random
import pickle
import re
import time
import math
import itertools

import numpy as np 
import pandas as pd
import matplotlib
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.nn as nn
from torch.utils.data import TensorDataset, ConcatDataset
import transformers
from transformers import BertForSequenceClassification, BertPreTrainedModel, BertConfig, BertModel, BertTokenizer
import sklearn
import nltk

In [None]:
# Select the compute device once; later cells move the model and every batch to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
else:
    print('No GPU available, using the CPU instead.')
print(device)

In [None]:
class config:
    """Central place for the hyper-parameters, paths and class weights used by the cells below."""

    # ---- reproducibility / data volume ----
    RANDOM_SEED = 41
    # Passed to MultiTaskSampler as `samples_num`; None makes the sampler use the
    # size of the largest task dataset instead.
    NUMBER_OF_SAMPLES = None

    # ---- paths ----
    PATH_TO_BERT = '../post_trained_model'  # path to BERT weights
    PATH_TO_SAVE = './models'               # where checkpoints and prediction pickles are written

    # ---- optimisation ----
    BATCH_SIZE = 8      # start with a small size
    VAL_BATCH_SIZE = 8
    EPOCHS = 6
    LR = 3e-5
    WARMUP_PROP = 0.1   # fraction of the optimizer steps used for linear warm-up
    DROPOUT = 0.1
    FREEZING_LAYERS = [5, 6, 7, 8]  # indices of the BERT encoder layers whose weights are frozen

    # ---- number of classes per task ----
    LABELS = 3          # main (stance) task
    SENTI_LABELS = 3
    ABSA_LABELS = 3

    # ---- per-class loss weights ----
    # Computed weights for every class label of the three tasks (output of the
    # compute_weights function in the Utils cell); they are passed to the loss function.
    STANCE_WGT = [2.021585720215857, 3.877388535031847, 4.04149377593361]
    VAL_WGT = [1, 1, 1]
    SENTI_WGT = [6.163498098859316, 9.765060240963855, 1.3598993288590604]
    ABSA_WGT = [3.267559621038876, 5.474548440065681, 1.9558075870160345]

# Tokenizer shared by every dataset-preparation cell below; the cased multilingual vocabulary is
# used, so lower-casing is disabled.
# NOTE(review): the model weights are loaded from config.PATH_TO_BERT, not from this checkpoint name —
# presumably both share the same vocabulary; confirm.
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased', do_lower_case=False)

In [None]:
def seed_everything(seed=config.RANDOM_SEED):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Also exports PYTHONHASHSEED and switches cuDNN to deterministic mode.

    Args:
        seed: integer seed; defaults to config.RANDOM_SEED.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    # The individual generators are independent, so the seeding order does not matter.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True


seed_everything()

In [None]:
class Prepare_Input():
    """Turn (text, target) pairs into fixed-length BERT inputs.

    Every pair is encoded as ``[CLS] text [SEP] target [SEP]`` and padded with
    zeros up to ``max_seq_length``.

    Args:
        tokenizer: object exposing ``tokenize(str)`` and ``convert_tokens_to_ids(list)``.
        max_seq_length: length of every produced id / mask / segment list.
    """

    # [CLS] + [SEP] + [SEP] are always added around the two sequences.
    _NUM_SPECIAL_TOKENS = 3

    def __init__(self, tokenizer,max_seq_length):
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length

    def _check_length(self, tokens):
        """Raise IndexError (after printing the tokens) when they do not fit in max_seq_length."""
        if len(tokens) > self.max_seq_length:
            print(tokens)
            raise IndexError("Token length more than max seq length!")

    def _get_masks(self, tokens):
        """Mask for padding: 1 for every real token, 0 for every padding position."""
        self._check_length(tokens)
        return [1] * len(tokens) + [0] * (self.max_seq_length - len(tokens))

    def _get_segments(self, tokens):
        """Segments: 0 for the first sequence (including its [SEP]), 1 for the second.

        Padding positions are filled with 0.
        """
        self._check_length(tokens)
        segments = []
        current_segment_id = 0
        for token in tokens:
            segments.append(current_segment_id)
            if token == "[SEP]":
                # Everything after the first [SEP] belongs to the second sequence.
                current_segment_id = 1
        return segments + [0] * (self.max_seq_length - len(tokens))

    def _get_ids(self, tokens):
        """Token ids from the tokenizer vocab, right-padded with 0 up to max_seq_length."""
        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        return token_ids + [0] * (self.max_seq_length - len(token_ids))

    def _trim_input(self, text, target, t_max=500):
        """Tokenize both strings and truncate them so the assembled pair always fits.

        The text has priority: it keeps at most ``min(t_max, max_seq_length - 3)``
        tokens, and the target receives whatever room is left (possibly none).
        Previously an over-long text produced a negative target budget, which
        sliced the target incorrectly and made the later length check raise.

        Returns:
            (text_tokens, target_tokens) with
            ``len(text_tokens) + len(target_tokens) + 3 <= max_seq_length``.
        """
        text_tokens = self.tokenizer.tokenize(text)
        target_tokens = self.tokenizer.tokenize(target)
        budget = max(0, self.max_seq_length - self._NUM_SPECIAL_TOKENS)
        text_tokens = text_tokens[:min(t_max, budget)]
        target_tokens = target_tokens[:budget - len(text_tokens)]
        return text_tokens, target_tokens

    def _convert_to_bert_inputs(self, title, question):
        """Convert two token lists to [ids, masks, segments] for BERT.

        Raises:
            IndexError: if the assembled sequence is longer than max_seq_length.
        """
        stoken = ["[CLS]"] + title + ["[SEP]"] + question + ["[SEP]"]
        input_ids = self._get_ids(stoken)
        input_masks = self._get_masks(stoken)
        input_segments = self._get_segments(stoken)
        return [input_ids, input_masks, input_segments]

    def _encode_pairs(self, pairs):
        """Encode an iterable of (text, target) strings into three stacked tensors."""
        input_ids, input_masks, input_segments = [], [], []
        for text, target in pairs:
            text_tokens, target_tokens = self._trim_input(text, target)
            ids, masks, segments = self._convert_to_bert_inputs(text_tokens, target_tokens)
            input_ids.append(ids)
            input_masks.append(masks)
            input_segments.append(segments)
        return [torch.tensor(input_ids),
                torch.tensor(input_masks),
                torch.tensor(input_segments)]

    def compute_input_arays(self, df, auxSent):
        """Pair every row's ``text`` attribute with the same auxiliary sentence.

        Returns:
            [input_ids, input_masks, input_segments] tensors, one row per dataframe row.
        """
        return self._encode_pairs(
            (instance.text, auxSent) for _, instance in tqdm(df.iterrows()))

    def compute_input_arays_byCol(self, df, firstCol, secondCol):
        """Pair column ``firstCol`` (text) with column ``secondCol`` (target) row by row.

        Returns:
            [input_ids, input_masks, input_segments] tensors, one row per dataframe row.
        """
        return self._encode_pairs(
            (instance[firstCol], instance[secondCol]) for _, instance in tqdm(df.iterrows()))

    def compute_output_arrays(self, df, columns):
        """Return the selected column(s) of ``df`` as a NumPy array."""
        return np.asarray(df[columns])

In [None]:
#Utils

from nltk.corpus import stopwords

# Persian stop-word list, adapted from https://github.com/sobhe/hazm/blob/master/hazm/data/stopwords.dat
# (a few of the original entries were removed and some extra ones were added).
STOPWORD_LIST = ['و','در', 'به', 'از', 'که', 'این', 'را', 'با', 'است', 'برای', 'آن', 'یک', 'خود','تا', 'کرد', 'بر', 'هم', 'نیز', 'گفت', 'می\u200cشود',
                 'وی', 'شد', 'دارد', 'ما', 'یا', 'شده', 'باید', 'هر','آنها', 'بود', 'او', 'دیگر', 'دو', 'مورد', 'می\u200cکند', 'شود', 'کند', 'بین', 'پیش',
                 'شده_است', 'پس', 'نظر','اگر', 'هستند', 'من', 'کنند', 'باشد', 'چه', 'می', 'بخش', 'می\u200cکنند', 'همین', 'افزود', 'هایی', 'دارند', 'راه', 
                 'همچنین','روی', 'داد', 'داشت', 'سوی', 'میان', 'اینکه', 'شدن', 'بعد', 'کردن', 'برخی', 'کردند', 'می\u200cدهد', 'کرده_است', 'نسبت', 'شما', 
                 'چنین', 'طور', 'افراد', 'درباره', 'بار', 'می\u200cتواند', 'کرده', 'چون', 'طی', 'همان', 'آنان', 'می\u200cگوید', 'دیگری','خواهد_شد', 'کنیم',
                 'قابل', 'یعنی', 'می\u200cتوان', 'وارد', 'قبل', 'براساس', 'نیاز', 'گذاری', 'سازی', 'بوده_است', 'می\u200cشوند','وقتی', 'گرفت', 'جای', 'حالی',
                 'تغییر', 'پیدا', 'اکنون', 'تحت', 'باعث', 'مدت', 'فقط', 'تعداد', 'آیا', 'بیان', 'رو', 'شدند','کرده_اند', 'بودن', 'نوع', 'جاری', 'دهد', 'برابر',
                 'بوده', 'مربوط', 'امر', 'گیری', 'خصوص', 'آقای', 'اثر', 'کننده', 'بودند','فکر', 'کنار', 'سایر', 'کنید', 'ضمن', 'مانند', 'باز', 'می\u200cگیرد', 
                 'حل', 'پی', 'مثل', 'می\u200cرسد','اجرا', 'منظور', 'کسی', 'موجب', 'طول', 'امکان', 'آنچه', 'تعیین', 'گفته', 'شوند', 'جمع', 'گونه', 'تاکنون', 'رسید',
                 'ساله', 'گرفته', 'شده_اند', 'علت', 'داشته_باشد', 'خواهد_بود', 'طرف', 'تهیه', 'تبدیل', 'زیرا', 'می\u200cتوانند', 'بخشی', 'باشند', 'داده_است', 'حد',
                 'کسانی', 'می\u200cکرد', 'داریم', 'می\u200cباشد', 'دانست', 'ناشی', 'داشتند', 'دهه', 'می\u200cشد', 'ایشان', 'آنجا', 'گرفته_است','می\u200cآید', 'لحاظ',
                 'آنکه', 'داده', 'هستیم','اند', 'برداری', 'می\u200cکنیم', 'نشست', 'سهم', 'همیشه', 'آمد', 'اش', 'وگو', 'می\u200cکنم', 'طبق', 'جا', 'خواهد_کرد',
                 'نوعی', 'چگونه', 'رفت', 'هنگام', 'فوق','روش', 'سعی', 'بندی', 'شمار', 'مواجه', 'همچنان', 'سمت', 'داشته_است', 'چیز', 'پشت', 'آورد', 'حالا', 'روبه',
                 'سال\u200cهای','دادند', 'می\u200cکردند','عهده', 'جایی', 'دیگران', 'بروز', 'یکدیگر', 'آمده_است', 'کنم', 'سپس', 'کنندگان', 'خودش', 'همواره', 'یافته',
                 'شان', 'صرف', 'نمی\u200cشود', 'رسیدن', 'یابد', 'متر', 'ساز','داشته', 'کرده_بود', 'باره', 'نحوه', 'کردم', 'تو', 'شخصی', 'داشته_باشند', 'محسوب', 'پخش',
                 'داشتن', 'نظیر', 'آمده', 'گروهی', 'فردی', 'ع', 'همچون', 'خویش', 'کدام', 'دسته', 'سبب','عین', 'آوری', 'متاسفانه', 'بیرون','دار', 'ابتدا', 'افرادی',
                 'می\u200cگویند', 'سالهای', 'درون', 'نیستند', 'یافته_است', 'پر', 'خاطرنشان', 'گاه', 'جمعی', 'دوباره', 'می\u200cیابد','لذا', 'زاده', 'گردد', 'اینجا','ها','های','ی','یه','ای']
STOP_SET_PER = set(STOPWORD_LIST)
# nltk.download('stopwords')
# STOP_WORD_EN = stopwords.words('english')
# Compiled once instead of on every call.
_EMOJI_PATTERN = re.compile("["
                           u"\U0001F600-\U0001F64F"  # emoticons
                           u"\U0001F300-\U0001F5FF"  # symbols & pictographs
                           u"\U0001F680-\U0001F6FF"  # transport & map symbols
                           u"\U0001F1E0-\U0001F1FF"  # flags (iOS)
                           u"\U00002702-\U000027B0"
                           u"\U000024C2-\U0001F251"
                           "]+", flags=re.UNICODE)


def remove_emoji(text):
    """Delete every character that falls in the emoji code-point ranges above.

    NOTE(review): the last range (U+24C2..U+1F251) is very wide and also covers
    non-emoji scripts such as CJK; kept unchanged so preprocessing stays identical.
    """
    return _EMOJI_PATTERN.sub(r'', text)


def normalize_text(inp):
    """Clean a tweet-like string before tokenization.

    Steps, each applied to the result of the previous one: remove http(s) URLs,
    turn newlines into spaces, drop carriage returns, remove ``RT @user``
    prefixes, strip punctuation, remove emoji and collapse runs of spaces.

    Previously the newline and carriage-return steps restarted from the raw
    input, silently discarding the URL removal and the newline replacement.

    Args:
        inp: raw text.

    Returns:
        The normalized text.
    """
    url = r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
    tmp = re.sub(url, '', inp)
    # The former second "URL" pattern made the scheme optional and therefore matched
    # every plain word; with the chaining fixed it would erase the whole text, so it
    # is intentionally not applied.
    tmp = tmp.replace('\n', ' ')
    tmp = tmp.replace('\r', '')
    tmp = re.sub(r'(RT|rt)[ ]*@[ ]*[\S]+', r'', tmp)
    tmp = re.sub(r'[!"\$%&\'()*+,\-.\/:;=#@؟?\[\\\]^_`{|}~]*', '', tmp)
    tmp = remove_emoji(tmp)
    # NOTE(review): '@' has already been stripped with the punctuation above, so this
    # substitution cannot match; kept to preserve the original pipeline order.
    tmp = re.sub(r'@[\w_-]+', '@PERSON', tmp)
    tmp = re.sub(r'[ ]+', r' ', tmp)
    return tmp
    
def normalize_text_per(inp, stop_words=STOP_SET_PER):
    """Clean a Persian text and drop its stop words.

    Removes http(s) URLs, newlines / carriage returns, ``RT @user`` prefixes and
    emoji, filters out space-separated words found in ``stop_words``, replaces
    ``@mention`` with ``@PERSON`` and collapses runs of spaces. Punctuation is
    left untouched.

    Args:
        inp: raw text.
        stop_words: container of words to discard (defaults to STOP_SET_PER).

    Returns:
        The normalized text.
    """
    url = 'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+'
    cleaned = re.sub(url, '', inp).replace('\n', ' ').replace('\r', '')
    cleaned = remove_emoji(re.sub(r'(RT|rt)[ ]*@[ ]*[\S]+', r'', cleaned))
    # Newlines are already gone at this point, so splitting on single spaces is enough.
    kept_words = [token for token in cleaned.strip().split(' ') if token not in stop_words]
    cleaned = re.sub(r'@[\w_-]+', '@PERSON', ' '.join(kept_words))
    return re.sub(r'[ ]+', r' ', cleaned)

def compute_max_len(columnVals, auxSent = '', tokenizer = tokenizer):
    """Return (and print) the longest encoded length among the given texts.

    Each value is concatenated with ``auxSent`` and encoded with special tokens
    by ``tokenizer``; 0 is returned for an empty input.
    """
    longest = max(
        (len(tokenizer.encode(text + auxSent, add_special_tokens=True)) for text in columnVals),
        default=0,
    )
    print('Max sentence length: ', longest)
    return longest

def compute_weights(df, colName):
    """Compute (and print) inverse-frequency loss weights for a label column.

    The weight of a class is ``total_rows / rows_of_that_class``; weights are
    ordered by sorted unique label value.
    """
    counts = np.unique(df[colName], return_counts=True)[1]
    total = sum(counts)
    class_weights = []
    for class_count in counts:
        class_weights.append(total / class_count)
    print(class_weights)
    return class_weights

In [None]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler


class MultiTaskSampler(torch.utils.data.sampler.Sampler):
    """Yield indices of a ConcatDataset so that consecutive batches cycle over its sub-datasets.

    The index stream is laid out as ``batch_size`` indices of dataset 0, then
    ``batch_size`` of dataset 1, and so on, repeated until every dataset has
    contributed ``ceil(size / batch_size)`` batches, where ``size`` is
    ``samples_num`` when truthy, otherwise the length of the largest sub-dataset.
    A sub-dataset that runs out of samples is reshuffled and reused.
    """

    def __init__(self, dataset, batch_size, samples_num):
        self.dataset = dataset
        self.batch_size = batch_size
        self.number_of_datasets = len(dataset.datasets)
        # A falsy samples_num (None or 0) falls back to the largest sub-dataset.
        self.mainTarget_dataset_size = samples_num or max(len(sub) for sub in dataset.datasets)

    def __len__(self):
        batches_per_task = math.ceil(self.mainTarget_dataset_size / self.batch_size)
        return self.batch_size * batches_per_task * len(self.dataset.datasets)

    def __iter__(self):
        task_ids = range(self.number_of_datasets)
        samplers = [RandomSampler(self.dataset.datasets[task]) for task in task_ids]
        iterators = [iter(task_sampler) for task_sampler in samplers]
        # Offset that maps a sub-dataset-local index to its ConcatDataset index.
        offsets = [0] + self.dataset.cumulative_sizes[:-1]

        def draw(task):
            """Next global index of `task`, restarting its shuffled pass when exhausted."""
            try:
                local_index = next(iterators[task])
            except StopIteration:
                iterators[task] = iter(samplers[task])
                local_index = next(iterators[task])
            return local_index + offsets[task]

        ordered_indices = []
        total_samples = self.mainTarget_dataset_size * self.number_of_datasets
        samples_per_round = self.batch_size * self.number_of_datasets
        for _ in range(0, total_samples, samples_per_round):
            for task in task_ids:
                ordered_indices.extend([draw(task) for _ in range(self.batch_size)])
        return iter(ordered_indices)

In [None]:
from tqdm.notebook import tqdm

#### SemEval Dataset
!gdown --id 1_dgLPsScUlWM8nHGzHVyTU_vA5CPKBNL 

def prepare_dataFrames_semEval(file_path):
    """Load the tab-separated SemEval stance file into a dataframe.

    Every non-blank line must hold four fields: id, target, text, stance. The
    text is cleaned with normalize_text and the stance is converted to int.

    Args:
        file_path: path to the UTF-8 encoded file.

    Returns:
        DataFrame with columns ['id', 'target', 'text', 'stance'].
    """
    import pandas as pd
    rows = []
    # The context manager closes the handle even when a malformed line raises.
    with open(file_path, newline='\n', encoding='utf8') as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                # Blank lines (e.g. a trailing newline) carry no record.
                continue
            fields = stripped.split('\t')
            fields[2] = normalize_text(fields[2])
            fields[3] = int(fields[3])
            rows.append(fields)
    df = pd.DataFrame(rows, columns=['id','target','text','stance'])
    return df

# Build the stance (main task) training tensors from the SemEval file.
semEval = prepare_dataFrames_semEval('all-woMention.txt')
# 76 is the fixed sequence length (in tokens) every SemEval example is padded to.
train_dataset_obj = Prepare_Input(tokenizer,76)
# compute_input_arays_byCol returns [ids, masks, segments], in that order.
input_ids_train ,attention_masks_train, token_type_ids_train = train_dataset_obj.compute_input_arays_byCol(semEval, 'text', 'target')
labels_semEval = torch.tensor(semEval.stance.values.astype(np.int64))
# Tensor order (ids, segments, masks, labels) is the order in which the training loop unpacks a batch.
train_dataset = TensorDataset(input_ids_train, token_type_ids_train, attention_masks_train, labels_semEval)

#### SentiFars Dataset
!gdown --id 1D9LwNrgtP8nAs39zs_Ic4S3du_TaD4vz

def prepare_sentifars(file_path):
    """Load the tab-separated SentiFars file into a dataframe.

    Every non-blank line must hold two fields: text, sentiment. The text is
    cleaned with normalize_text_per and the sentiment is converted to int.

    Args:
        file_path: path to the UTF-8 encoded file.

    Returns:
        DataFrame with columns ['text', 'sentiment'].
    """
    import pandas as pd
    rows = []
    # The context manager closes the handle even when a malformed line raises.
    with open(file_path, newline='', encoding='utf8') as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                # Blank lines (e.g. a trailing newline) carry no record.
                continue
            fields = stripped.split('\t')
            fields[0] = normalize_text_per(fields[0])
            fields[1] = int(fields[1])
            rows.append(fields)
    df = pd.DataFrame(rows, columns=['text','sentiment'])
    return df
# Build the sentiment (auxiliary task) training tensors.
senti_df = prepare_sentifars('sentifars.txt')
# Shuffle the rows; no explicit random_state is passed to sample().
senti_df = senti_df.sample(frac=1).reset_index(drop=True)

# Single-sentence encoding, padded to 512 tokens.
# NOTE(review): `pad_to_max_length` is deprecated in recent transformers releases in favour of
# `padding='max_length'` (with `truncation=True`) — confirm against the installed version.
encoded_data_senti = tokenizer.batch_encode_plus(
    senti_df.text.values, 
    add_special_tokens=True, 
    return_attention_mask=True, 
    pad_to_max_length=True, 
    max_length=512, 
    return_tensors='pt'
)
input_ids_senti = encoded_data_senti['input_ids']
token_type_ids_senti = encoded_data_senti['token_type_ids']
attention_masks_senti = encoded_data_senti['attention_mask']
labels_senti = torch.tensor(senti_df.sentiment.values.astype(np.int64))
# Same tensor order as the other task datasets: (ids, segments, masks, labels).
senti_dataset = TensorDataset(input_ids_senti,token_type_ids_senti, attention_masks_senti, labels_senti)

#### Pars-ABSA Dataset
!gdown --id 1fgD7fDTLP7uLuIDM1nm2Wtn4FcOENWWK

def change_string_to_float_label_absa(df):
    """Replace the string polarity labels with integer codes, in place.

    Mapping: negative -> 0, neutral -> 1, positive -> 2. Other values are left
    untouched. The same (mutated) dataframe is returned.
    """
    for polarity_name, polarity_code in (('negative', 0), ('neutral', 1), ('positive', 2)):
        df.loc[df['polarity'] == polarity_name, 'polarity'] = polarity_code
    return df
# Build the aspect-based sentiment (auxiliary task) training tensors.
absa = pd.read_csv('Pars-ABSA.csv')
absa = change_string_to_float_label_absa(absa)
absa['text'] = absa.apply (lambda row: normalize_text_per(row['text']) , axis=1)

# `absa_dataset` first holds the input builder (512-token sequences) and is rebound to the
# final TensorDataset a few lines below.
absa_dataset = Prepare_Input(tokenizer,512)
# The text is paired with its aspect; the builder returns [ids, masks, segments].
input_ids_absa ,attention_masks_absa, token_type_ids_absa = absa_dataset.compute_input_arays_byCol(absa, 'text', 'aspect')
labels_absa = torch.tensor(absa.polarity.values.astype(np.int64))
# Same tensor order as the other task datasets: (ids, segments, masks, labels).
absa_dataset = TensorDataset(input_ids_absa,token_type_ids_absa, attention_masks_absa, labels_absa)

####  CONCAT datasets
# Dataset order (stance, sentiment, ABSA) fixes the task order of consecutive batches:
# the training loop relies on batch index % 3 being 0, 1 and 2 respectively.
concat_train = ConcatDataset([train_dataset, senti_dataset, absa_dataset])
# The sampler emits batch_size indices per task in turn, so with the same batch_size here
# every batch contains examples of a single task; shuffling is done inside the sampler.
custom_dataloader = DataLoader(
            concat_train, 
            sampler = MultiTaskSampler(dataset=concat_train, batch_size = config.BATCH_SIZE, samples_num = config.NUMBER_OF_SAMPLES), 
            batch_size = config.BATCH_SIZE, 
            shuffle = False
        )

In [None]:
from sklearn.model_selection import StratifiedKFold

#### persian Dataset
#### Replace the <FILE_ID> with the id of the persian dataset your are provided by e-mail. 
!gdown --id <FILE_ID>  

def change_string_to_float_label(df):
    """Encode stance labels as integers and rename the targets, in place.

    ``majorityLabel``: AGAINST -> 0, NEITHER -> 1, FAVOR -> 2.
    ``target``: the short target keys are replaced with the phrases fed to the
    model. Values not listed are left untouched. The same (mutated) dataframe
    is returned.
    """
    stance_codes = (('AGAINST', 0), ('NEITHER', 1), ('FAVOR', 2))
    target_names = (
        ('barjam', 'برجام'),
        ('trump', 'Donald Trump'),
        ('raeesi', 'رییسی'),
        ('barabariJensiati', 'Gender Equality'),
        ('rouhani', 'حسن روحانی'),
    )
    for column, replacements in (('majorityLabel', stance_codes), ('target', target_names)):
        for old_value, new_value in replacements:
            df.loc[df[column] == old_value, column] = new_value
    return df

# Load the Persian stance dataset; it is used only for validation and testing.
test = pd.read_csv('StancePers(WHOLE) - wo disagreements.tsv', sep = '\t', header=0)
# change_string_to_float_label mutates and returns its argument, so `test` and `test_whole`
# refer to the same dataframe.
test_whole = change_string_to_float_label(test)
test_whole['text'] = test_whole.apply (lambda row: normalize_text(row['text']) , axis=1)
labels_val_all = torch.tensor(test_whole.majorityLabel.values.astype(np.int64))


# Cross-validation is performed to split the Persian dataset into a validation and a test set.
# Since Colab restricts usage after one round of training, the CV split is selected manually
# at the beginning of every session.
skf = StratifiedKFold(n_splits=3, shuffle = True, random_state = config.RANDOM_SEED)
# The return value of get_n_splits is not used.
skf.get_n_splits(test_whole, labels_val_all)
skf_splits = skf.split(test_whole, labels_val_all)
# Take the first split; un-comment one or both lines below to advance to the second or third split.
result = next(skf_splits, None)
# result = next(skf_splits, None)
# result = next(skf_splits, None)
# The first index array of the split is used as the test set and the second one as the
# validation set.
test_index, val_index = result[0], result[1]
# 140 is the fixed sequence length (in tokens) used for the Persian examples.
stancePers_dataset_obj = Prepare_Input(tokenizer,140)

# Input preparation of the validation set (the builder returns [ids, masks, segments]).
input_ids_validation_set ,attention_masks_validation_set, token_type_ids_validation_set = stancePers_dataset_obj.compute_input_arays_byCol(test_whole.iloc[val_index], 'text', 'target')
labels_validation_set = labels_val_all[val_index]
persian_validation_dataset = TensorDataset(input_ids_validation_set,token_type_ids_validation_set, attention_masks_validation_set, labels_validation_set)
persian_validation_dataloader = DataLoader(
            persian_validation_dataset, 
            sampler = SequentialSampler(persian_validation_dataset), 
            batch_size = config.VAL_BATCH_SIZE 
        )

# Input preparation of the test set (same tensor layout as the validation set).
input_ids_test ,attention_masks_test, token_type_ids_test = stancePers_dataset_obj.compute_input_arays_byCol(test_whole.iloc[test_index], 'text', 'target')
labels_test = labels_val_all[test_index]
persian_test_dataset = TensorDataset(input_ids_test,token_type_ids_test, attention_masks_test, labels_test)
persian_test_dataloader = DataLoader(
            persian_test_dataset, 
            sampler = SequentialSampler(persian_test_dataset), 
            batch_size = config.VAL_BATCH_SIZE 
        )

In [None]:
from torch import nn
from transformers import BertForSequenceClassification, BertPreTrainedModel, BertConfig, BertModel

class Model(nn.Module, config):
    """Shared BERT encoder with one classification head per task.

    * stance: dropout + linear layer on ``bert_out[1]``.
    * sentiment / absa: a task-specific bidirectional LSTM over ``bert_out[0]``,
      whose two final hidden states are concatenated, passed through dropout and
      classified by a task-specific linear layer.
    """
    def __init__(self):
        super().__init__()
        # Shared encoder, loaded from the weights stored at config.PATH_TO_BERT.
        self.bert = BertModel.from_pretrained(config.PATH_TO_BERT)
        # Width of the encoder output, read from the pooler's dense layer.
        self.num_features = self.bert.pooler.dense.out_features
        self.stance_labels = config.LABELS
        self.senti_labels = config.SENTI_LABELS
        self.absa_labels = config.ABSA_LABELS
        # One dropout module shared by the three heads.
        self.drop = nn.Dropout(config.DROPOUT)
        # Stance head.
        self.logit = nn.Linear(self.num_features, self.stance_labels)
        # Sentiment head; the linear layer takes 2 * num_features because both LSTM directions are concatenated.
        self.lstm_senti = nn.LSTM(self.num_features, self.num_features, num_layers=1, bidirectional=True, batch_first=True)
        self.logit_senti = nn.Linear(self.num_features*2, self.senti_labels)
        # ABSA head, same layout as the sentiment head.
        self.lstm_absa = nn.LSTM(self.num_features, self.num_features, num_layers=1, bidirectional=True, batch_first=True)
        self.logit_absa = nn.Linear(self.num_features*2, self.absa_labels)

    def forward(self, tokens_tensors, segments_tensors, masks_tensors,task_type=None):
        """Run the encoder and the head selected by ``task_type``.

        Args:
            tokens_tensors: input ids, passed to BERT as ``input_ids``.
            segments_tensors: passed to BERT as ``token_type_ids``.
            masks_tensors: passed to BERT as ``attention_mask``.
            task_type: 'stance', 'sentiment' or 'absa'.

        Returns:
            The logits of the selected head, or None when ``task_type`` is none of
            the three names above.
        """
        bert_out = self.bert(input_ids=tokens_tensors, token_type_ids=segments_tensors, attention_mask=masks_tensors)
        output = None
        if task_type == 'stance':
          output  = self.logit(self.drop(bert_out[1]))

        elif task_type == 'sentiment':
          enc_hiddens_senti, (last_hidden_senti, last_cell_senti) = self.lstm_senti(bert_out[0])
          # Concatenate the final hidden states of the two LSTM directions.
          output_hidden_senti = torch.cat((last_hidden_senti[0], last_hidden_senti[1]), dim=1)
          ou_senti = self.drop(output_hidden_senti)
          output  = self.logit_senti(ou_senti)
          
        elif task_type == 'absa':
          enc_hiddens_absa, (last_hidden_absa, last_cell_absa) = self.lstm_absa(bert_out[0])
          # Concatenate the final hidden states of the two LSTM directions.
          output_hidden_absa = torch.cat((last_hidden_absa[0], last_hidden_absa[1]), dim=1)
          ou_absa = self.drop(output_hidden_absa)
          output  = self.logit_absa(ou_absa)

        # Drop the local references to the inputs and release cached CUDA memory after every forward pass.
        # NOTE(review): emptying the CUDA cache on every call may slow training down — confirm it is needed.
        del tokens_tensors
        del segments_tensors
        del masks_tensors
        if device == torch.device("cuda"):
          torch.cuda.empty_cache()
        return output

# Single model instance used by the freezing, evaluation and training cells.
model = Model().to(device)

In [None]:
# Freeze the encoder layers listed in config.FREEZING_LAYERS. This is meant to limit
# overfitting and catastrophic forgetting, and it also reduces the training time.
for frozen_idx in config.FREEZING_LAYERS:
    for weight in model.bert.encoder.layer[frozen_idx].parameters():
        weight.requires_grad = False

In [None]:
### Evaluation

from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt

TASK_CLASSES = {
    'stance': ['against','neither','favor'],
    'sentiment': ['negative','objective','positive'],
    'absa': ['negative', 'neutral', 'positive']}

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """Print and plot the confusion matrix on the current matplotlib figure.

    Args:
        cm: square confusion matrix (rows = true labels, columns = predictions).
        classes: tick labels, one per class.
        normalize: when True every row is divided by its sum before printing and
            plotting, so the heat-map, the colour bar and the cell texts all show
            the same (normalized) values.
        title: figure title.
        cmap: matplotlib colour map used for the heat-map.
    """
    # Normalize first so that the image drawn below uses the same values as the cell texts.
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Cells darker than half of the maximum get white text to stay readable.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        # Ratios are shortened to two decimals; raw counts are written unchanged.
        cell_text = format(cm[i, j], '.2f') if normalize else cm[i, j]
        plt.text(j, i, cell_text,
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')


def eval_save(last_favg, step, task_name):
  """Evaluate the model on the dev set and save the best checkpoint.

  The score is the mean of the F1 of class 0 (against) and class 2 (favor) on
  the Persian validation set. When it is at least as good as ``last_favg`` the
  model, optimizer and scheduler states are saved and the test set is evaluated
  with just_eval.

  Relies on the notebook globals ``model``, ``optimizer``, ``scheduler``,
  ``epoch``, ``device`` and the Persian dataloaders. The model is left in eval
  mode; the caller has to switch it back to train mode.

  Args:
    last_favg: best score seen so far.
    step: current optimizer step, stored in the checkpoint.
    task_name: key of TASK_CLASSES used to label the confusion matrix. The
      forward pass itself always uses the 'stance' head.

  Returns:
    The (possibly updated) best score.
  """
  model.eval()
  preds = []
  truths = []
  with torch.no_grad():
      for data in tqdm(persian_validation_dataloader):
        tokens_tensor, segments_tensor, masks_tensor, labels_tensor = [k.to(device) for k in data if k is not None]
        output  = model(tokens_tensor, segments_tensor, masks_tensor,'stance')
        probs = F.softmax(output, dim=-1).detach().cpu().numpy()
        preds += list(np.argmax(probs, axis=1).flatten())
        truths += list(labels_tensor.detach().cpu().numpy().flatten())

  score = classification_report(truths, preds, digits=4, output_dict= True)
  # A class that is absent from the report contributes an F1 of 0.
  against_score = score['0']['f1-score'] if '0' in score else 0
  fav_score = score['2']['f1-score'] if '2' in score else 0
  f_avg = (against_score + fav_score ) / 2
  # `<=` means that a tie also overwrites the saved checkpoint.
  if last_favg <= f_avg:
    last_favg = f_avg
    print('saved model at score '+ str(f_avg) + 'and step: ' + str(step))
    torch.save({
        'score':last_favg,
        'step': step,
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'scheduler': scheduler.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
        }, config.PATH_TO_SAVE + '/stance-best-semEval-CV1.pth')
    just_eval(persian_test_dataloader, 'stance')
  print('_____________ evaluating with persian validation set_____________')
  print(classification_report(truths, preds, digits=4))
  cnf_matrix = confusion_matrix(truths, preds)
  np.set_printoptions(precision=2)
  # One figure only: the second, empty figure that used to be opened before show() is gone.
  plt.figure()
  plot_confusion_matrix(cnf_matrix,TASK_CLASSES[task_name], title='Confusion matrix, without normalization')
  plt.show()
  return last_favg

def just_eval(input_dataloader, task_name):
  """Evaluate the model on test set for reporting model performance.
  Has nothing to do with model selection.

  Runs the head selected by ``task_name`` over ``input_dataloader``, pickles the
  true and predicted labels under config.PATH_TO_SAVE, then prints the
  classification report and plots the confusion matrix.

  Relies on the notebook globals ``model`` and ``device``; the model is left in
  eval mode.

  Args:
    input_dataloader: yields (ids, segments, masks, labels) batches.
    task_name: head to use, also the key of TASK_CLASSES for the plot labels.
  """
  model.eval()
  preds = []
  truths = []
  with torch.no_grad():
      for data in tqdm(input_dataloader):
        tokens_tensor, segments_tensor, masks_tensor, labels_tensor = [k.to(device) for k in data if k is not None]
        output  = model(tokens_tensor, segments_tensor, masks_tensor,task_name)
        probs = F.softmax(output, dim=-1).detach().cpu().numpy()
        preds += list(np.argmax(probs, axis=1).flatten())
        truths += list(labels_tensor.detach().cpu().numpy().flatten())

  # Save true and predicted labels of the test set; `with` closes the files.
  with open(config.PATH_TO_SAVE + '/truth-stance-best-semEval-CV1.pickle', 'wb') as tp:
    pickle.dump(truths, tp)
  with open(config.PATH_TO_SAVE + '/preds-stance-best-semEval-CV1.pickle', 'wb') as fp:
    pickle.dump(preds, fp)
  print('_____________ evaluating with persian test set_____________')
  print(classification_report(truths, preds, digits=4))
  cnf_matrix = confusion_matrix(truths, preds)
  np.set_printoptions(precision=2)
  # One figure only: the second, empty figure that used to be opened before show() is gone.
  plt.figure()
  plot_confusion_matrix(cnf_matrix,TASK_CLASSES[task_name], title='Confusion matrix, without normalization')
  plt.show()
      

In [None]:
%xmode Plain
%pdb on   
# Multi-task training loop. Batches arrive in the fixed order stance, sentiment, ABSA
# (batch index % 3 == 0, 1, 2); gradients of the three batches are accumulated and a single
# optimizer/scheduler step is taken after the ABSA batch.
step = 0        # number of optimizer steps taken so far
last_favg = 0   # best validation score seen so far

optimizer = torch.optim.AdamW(model.parameters(), lr=config.LR)
# One optimizer step is taken per three batches, hence the division by 3.
num_warmup_steps = int(config.WARMUP_PROP * config.EPOCHS * (len(custom_dataloader)/3) )
num_training_steps = config.EPOCHS * (len(custom_dataloader)/3)
scheduler = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps) 
start_time = time.time()

for epoch in range(config.EPOCHS):
    print('____________________'+ str(epoch) + '_____________________')
    model.train()
    optimizer.zero_grad()
    avg_loss = 0
    loss = 0.0
    for index, data in enumerate(tqdm(custom_dataloader)):             
        # Batch layout follows the TensorDatasets: (ids, segments, masks, labels).
        tokens_tensor, segments_tensor, masks_tensor, labels_tensor = [k.to(device) for k in data if k is not None]
        #SENTI
        if index % 3 == 1:
          output = model(tokens_tensor, segments_tensor, masks_tensor,'sentiment')
          if device == torch.device("cuda"):
            class_weights = torch.tensor(config.SENTI_WGT, dtype=torch.float).cuda()
          else:
            class_weights = torch.tensor(config.SENTI_WGT, dtype=torch.float)
          loss_fct = torch.nn.CrossEntropyLoss(weight = class_weights)
          senti_loss = loss_fct(output, labels_tensor)
          # Each task contributes roughly one third of the accumulated gradient.
          loss =  0.333 * senti_loss
          loss.backward()

        elif index % 3 == 2:
          #ABSA
          output = model(tokens_tensor, segments_tensor, masks_tensor,'absa')
          if device == torch.device("cuda"):
            class_weights = torch.tensor(config.ABSA_WGT, dtype=torch.float).cuda()
          else:
            class_weights = torch.tensor(config.ABSA_WGT, dtype=torch.float)
          loss_fct = torch.nn.CrossEntropyLoss(weight = class_weights)
          absa_loss = loss_fct(output, labels_tensor)
          loss =  0.333 * absa_loss
          loss.backward()
          # Last task of the triple: apply the accumulated gradients and reset them.
          optimizer.step()
          scheduler.step()
          model.zero_grad()
          optimizer.zero_grad()
          loss = 0.0
          step = step + 1
          if step % 1000 == 0: # evaluate model on val_set every 1k steps and save the best checkpoint
            last_favg = eval_save(last_favg, step, 'stance') 
            # eval_save leaves the model in eval mode.
            model.train()
        else:
          #### MAIN SD dataset
          # print(tokenizer.decode(tokens_tensor[0]))
          output = model(tokens_tensor, segments_tensor, masks_tensor,  'stance')
          if device == torch.device("cuda"):
            class_weights = torch.tensor(config.STANCE_WGT, dtype=torch.float).cuda()
          else:
            class_weights = torch.tensor(config.STANCE_WGT, dtype=torch.float)
          loss_fct = torch.nn.CrossEntropyLoss(weight = class_weights )
          loss = 0.333 * loss_fct(output, labels_tensor)
          loss.backward()
          
    # Save the model's checkpoint after every epoch, to be able to resume training if the connection is lost.
    torch.save({
    'score':last_favg,
    'step': step,
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'scheduler': scheduler.state_dict(),
    'optimizer_state_dict': optimizer.state_dict()
    }, config.PATH_TO_SAVE + '/stance-checkpoint-semEval-CV1.pth')

end_time = time.time() - start_time
print('Elapsed time: {}'.format(end_time))


In [None]:
### load model checkpoints to evaluate or resume training
# The optimizer and scheduler are rebuilt exactly as in the training cell so that their
# saved states can be restored into them.
optimizer = torch.optim.AdamW(model.parameters(), lr=config.LR)
num_warmup_steps = int(config.WARMUP_PROP * config.EPOCHS * (len(custom_dataloader)/3) )
num_training_steps = config.EPOCHS * (len(custom_dataloader)/3)
scheduler = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)

# map_location lets a checkpoint written on a GPU be restored in a CPU-only session.
# NOTE: torch.load unpickles the file — only load checkpoints you produced yourself.
mm = torch.load(config.PATH_TO_SAVE + '/stance-best-semEval-CV1.pth', map_location=device)
# strict=False ignores missing/unexpected keys instead of raising.
model.load_state_dict(mm['model_state_dict'], strict = False)
optimizer.load_state_dict(mm['optimizer_state_dict'])
scheduler.load_state_dict(mm['scheduler'])