In [38]:
import os 
from importlib import reload  
import numpy as np
import torch,time
from transformers import BartModel,BartConfig,BartForConditionalGeneration,BartForCausalLM
from tqdm.notebook import tqdm

In [39]:
import sys
  
# setting path to enable import from the parent directory
sys.path.append('../')

In [40]:
class EarlyStopping(object):
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, score_at_min1=0,patience=100, verbose=False, delta=0, path='checkpoint.pt',
                 trace_func=print,save_epochwise=False):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = score_at_min1
        self.early_stop = False
        self.delta = delta
        self.path = path
        self.trace_func = trace_func
        self.state_dict_list=[None]*patience
        self.improved=0
        self.stop_update=0
        self.save_model_counter=0
        self.save_epochwise=save_epochwise
        self.times_improved=0
        self.activated=False
    def activate(self,s1,s2):
        if not self.activated and s1>0 and s2>0: self.activated=True
    def __call__(self, score, epoch,model):
        if not self.activated: return None
        self.save_model_counter = (self.save_model_counter + 1) % 4
        if not self.stop_update:
            if self.verbose:
                self.trace_func(f'\033[91m The val score  of epoch {epoch} is {score:.4f} \033[0m')
            if score < self.best_score + self.delta:
                self.counter += 1
                self.trace_func(f'\033[93m EarlyStopping counter: {self.counter} out of {self.patience} \033[0m')
                if self.counter >= self.patience:
                    self.early_stop = True
                self.improved=0
            else:
                self.save_checkpoint(score, model,epoch)
                self.best_score = score
                self.counter = 0
                self.improved=1
        else:
            self.improved=0 #not needed though

    def save_checkpoint(self, score, model,epoch):
        '''Saves model when validation loss decrease.'''
        # if self.verbose:
        self.times_improved+=1
        self.trace_func(f'\033[92m Validation score improved ({self.best_score:.4f} --> {score:.4f}). \033[0m')
        if self.save_epochwise:
            path=self.path+"_"+str(self.times_improved)+"_"+str(epoch)
        else:
            path=self.path
        torch.save(model.state_dict(), path)

In [41]:
from preprocess import make_dataset
import pathlib
train=make_dataset(pathlib.Path("../data/protechn_corpus_eval/train/"))
val=make_dataset(pathlib.Path("../data/protechn_corpus_eval/dev/"))
test=make_dataset(pathlib.Path("../data/protechn_corpus_eval/test/"))

In [42]:
from transformers import BartTokenizer
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')

In [43]:
from preprocess import make_bert_dataset,make_bert_testset
train_=make_bert_testset(train)
val_=make_bert_testset(val)
test_=make_bert_testset(test)
train_sents=[ list(map(lambda x: x[1] if x[0]==0 else " "+x[1], enumerate(i))) for d in train_[0] for i in d]
val_sents=[ list(map(lambda x: x[1] if x[0]==0 else " "+x[1], enumerate(i))) for d in val_[0] for i in d]
test_sents=[ list(map(lambda x: x[1] if x[0]==0 else " "+x[1], enumerate(i))) for d in test_[0] for i in d]
def create_labels(dataset):
    temp=[ set(i)-set("O") for d in dataset[1] for i in d]
    return [ next(iter(i)) if len(i)>0 else "O"  for i in temp]
train_ls=create_labels(train_)
val_ls=create_labels(val_)
test_ls=create_labels(test_)
train_y_txt=[ i for d in train_[1] for i in d]
val_y_txt=[ i for d in val_[1] for i in d]
test_y_txt=[ i for d in test_[1] for i in d]

In [44]:
labels_set={'Appeal_to_Authority',
 'Appeal_to_fear-prejudice',
 'Bandwagon',
 'Black-and-White_Fallacy',
 'Causal_Oversimplification',
 'Doubt',
 'Exaggeration,Minimisation',
 'Flag-Waving',
 'Loaded_Language',
 'Name_Calling,Labeling',
 'O',
 'Obfuscation,Intentional_Vagueness,Confusion',
 'Red_Herring',
 'Reductio_ad_hitlerum',
 'Repetition',
 'Slogans',
 'Straw_Men',
 'Thought-terminating_Cliches',
 'Whataboutism'}
train_idx_bylabel={x: [i for i in range(len(train_ls)) if train_ls[i]==x] for x in labels_set} 
val_idx_bylabel={x: [i for i in range(len(val_ls)) if val_ls[i]==x] for x in labels_set} 
test_idx_bylabel={x: [i for i in range(len(test_ls)) if test_ls[i]==x] for x in labels_set} 

In [45]:
class BinaryClassDataset(torch.utils.data.Dataset):
    def __init__(self, x,y,y_txt,it_is_train=1,pos_or_neg=None,fix_seq_len=256,balance=False,
                 specific_label=None,for_protos=False):
        self.x=[]
        self.attn_mask=[]
        self.labels_mask=[]
        self.y_txt=[]
        self.y=[]
        self.labels_ids={}
        for i in labels_set:
            self.labels_ids[i]=len(self.labels_ids)
        self.y_fine_int=[]
        it_is_train_proxy=it_is_train
        for split_sent,y_tags,y_sent in zip(x,y_txt,y):
            if specific_label is not None and specific_label!=y_sent: continue
            if pos_or_neg=="pos" and y_sent=="O": continue
            elif pos_or_neg=="neg" and y_sent!="O": continue                
            if y_sent=="O":
                it_is_train=0
            else:
                it_is_train=it_is_train_proxy               
            tmp=tokenizer(split_sent,is_split_into_words=False)["input_ids"]
            tmp_x=[]
            tmp_attn=[]
            tmp_y=[]
            for i,chunk in enumerate(tmp):
                if for_protos and y_tags[i]=="O":
                    continue
                tmp_y.extend([y_tags[i]]*len(chunk))
                if y_tags[i]!="O":
                    mask=1
                else:
                    if it_is_train:
                        mask=0
                    else:
                        mask=1
                tmp_x.extend(chunk[1:-1])
                tmp_attn.extend([mask]*(len(chunk)-2))
            tmp_x.append(tokenizer.eos_token_id)
            tmp_x.insert(0,tokenizer.bos_token_id)
            tmp_attn.append(tmp_attn[-1])
            tmp_attn.insert(0,tmp_attn[0])
            self.x.append(tmp_x)
            self.attn_mask.append(tmp_attn)
            self.y_txt.append(tmp_y)
            self.y.append(1 if y_sent!="O" else 0)
            self.y_fine_int.append(self.labels_ids[y_sent])
        for tokid_sent in self.x:
            tokid_sent.extend([tokenizer.pad_token_id]*(fix_seq_len-len(tokid_sent)))
        for mask_vec in self.attn_mask:
            mask_vec.extend([0]*(fix_seq_len-len(mask_vec)))
        if balance:
            num_pos=np.sum(self.y)
            assert num_pos<len(self.y_fine_int)//2
            
            pos_indices=np.random.choice([i for i in range(len(self.y)) if self.y[i]==1],
                                         size=len(self.y)-2*num_pos,replace=True)
            self.x.extend([self.x[i] for i in pos_indices])
            self.y.extend([1 for i in pos_indices])
            self.y_fine_int.extend([self.y_fine_int[i] for i in pos_indices])
            self.attn_mask.extend([self.attn_mask[i] for i in pos_indices])
        self.fix_seq_len=fix_seq_len
    def __len__(self):
        return len(self.x)
    def __getitem__(self, idx):
        return self.x[idx],self.attn_mask[idx],self.y[idx]
    def collate_fn(self,batch):        
        return (torch.LongTensor([i[0] for i in batch]),
                torch.Tensor([i[1] for i in batch]),
                torch.LongTensor([i[2] for i in batch]))




In [46]:
train_dataset=BinaryClassDataset(train_sents,train_ls,train_y_txt,it_is_train=0,balance=True)
val_dataset=BinaryClassDataset(val_sents,val_ls,val_y_txt,it_is_train=0)
test_dataset=BinaryClassDataset(test_sents,test_ls,test_y_txt,it_is_train=0)

In [47]:
train_dl=torch.utils.data.DataLoader(train_dataset,batch_size=20,shuffle=True,
                                     collate_fn=train_dataset.collate_fn)
val_dl=torch.utils.data.DataLoader(val_dataset,batch_size=128,shuffle=False,
                                     collate_fn=val_dataset.collate_fn)
test_dl=torch.utils.data.DataLoader(test_dataset,batch_size=128,shuffle=False,
                                     collate_fn=test_dataset.collate_fn)

In [48]:
def print_logs(file,info,epoch,val_loss,mac_val_prec,mac_val_rec,mac_val_f1,mic_val_prec,mic_val_rec,mic_val_f1):
    logs=[]
    s=" ".join((info+" epoch",str(epoch),"Total loss %.4f"%(val_loss),"\n"))
    logs.append(s)
    print(s)
    s=" ".join((info+" epoch",str(epoch),"Prec",str(mac_val_prec),"\n"))
    logs.append(s)
    print(s)
    s=" ".join((info+" epoch",str(epoch),"Recall",str(mac_val_rec),"\n"))
    logs.append(s)
    print(s)
    s=" ".join((info+" epoch",str(epoch),"F1",str(mac_val_f1),"\n"))
    logs.append(s)
    print(s)
#     print("epoch",epoch,"MICRO val precision %.4f, recall %.4f, f1 %.4f,"%(mic_val_prec,mic_val_rec,mic_val_f1))
    print() 
    logs.append("\n")
    f=open(file,"a")
    f.writelines(logs)
    f.close()

In [49]:
from sklearn.metrics import precision_recall_fscore_support
def evaluate(dl,model_new=None,path=None,modelclass=None):
    assert (model_new is not None) ^ (path is not None)
    if path is not None:
        model_new=modelclass().cuda()
        model_new.load_state_dict(torch.load(path))
    loader = tqdm(dl, total=len(dl), unit="batches")
    total_len=0
    model_new.eval()    
    with torch.no_grad():
        total_loss=0
        tts=0
        y_pred=[]
        y_true=[]
        for batch in loader:
            input_ids,attn_mask,y=batch
            classfn_out,loss=model_new(input_ids,attn_mask,y,use_decoder=False,use_classfn=1)
            if classfn_out.ndim==1:
                predict=torch.zeros_like(y)
                predict[classfn_out>0]=1
            else:
                predict=torch.argmax(classfn_out,dim=1)
                
            y_pred.append(predict.cpu().numpy())
            y_true.append(y.cpu().numpy())
            total_loss+=(len(input_ids)*loss[0].item())
            total_len+=len(input_ids)
        total_loss=total_loss/total_len
        mac_prec,mac_recall,mac_f1_score,_=precision_recall_fscore_support(np.concatenate(y_true),np.concatenate(y_pred),labels=[0,1])
        mic_prec,mic_recall,mic_f1_score,_=0,0,0,0

    return total_loss,mac_prec,mac_recall,mac_f1_score,mic_prec,mic_recall,mic_f1_score

In [50]:
num_prototypes=20
num_pos_protos=19

In [51]:
class SimpleProtoTex(torch.nn.Module):
    def __init__(self,n_classes=2):
        super().__init__()
        self.bart_model=BartForConditionalGeneration.from_pretrained('facebook/bart-large')   
        self.bart_out_dim=self.bart_model.config.d_model
        self.max_position_embeddings=256
        self.num_protos=num_prototypes
        self.prototypes=torch.nn.Parameter(torch.rand(self.num_protos,self.max_position_embeddings,self.bart_out_dim))
        self.classfn_model=torch.nn.Linear(self.num_protos,2)
        self.loss_fn=torch.nn.CrossEntropyLoss(reduction="mean")
        
        self.set_encoder_status(True)
        self.set_decoder_status(False)
        self.set_protos_status(False)
        self.set_classfn_status(False)
        
        self.BNLayer=torch.nn.BatchNorm1d(self.num_protos)
        
    def set_encoder_status(self,status=True):
        self.num_enc_layers=len(self.bart_model.base_model.encoder.layers)
        for (i,x) in enumerate(self.bart_model.base_model.encoder.layers):
            requires_grad=False
            if i==self.num_enc_layers-1: requires_grad=status
            for y in x.parameters():
                y.requires_grad=requires_grad
    def set_decoder_status(self,status=True):
        self.num_dec_layers=len(self.bart_model.base_model.decoder.layers)
        for (i,x) in enumerate(self.bart_model.base_model.decoder.layers):
            requires_grad=False
            if i==self.num_dec_layers-1: requires_grad=status
            for y in x.parameters():
                y.requires_grad=requires_grad
    def set_classfn_status(self,status=True):
        self.classfn_model.requires_grad=status
    def set_protos_status(self,status=True):
        self.prototypes.requires_grad=status       
        

    def forward(self,input_ids,attn_mask,y,use_decoder=1,use_classfn=0,use_rc=0,use_p1=0,use_p2=0,rc_loss_lamb=0.95,p1_lamb=0.93,p2_lamb=0.92):
        batch_size=input_ids.size(0)
        if use_decoder:
            labels=input_ids.cuda()+0 
            labels[labels==self.bart_model.config.pad_token_id]=-100
            bart_output=self.bart_model(labels,attn_mask.cuda(),labels=labels,
                                        output_attentions=False,output_hidden_states=False)
            rc_loss,last_hidden_state=batch_size*bart_output.loss,bart_output.encoder_last_hidden_state
        else:
            rc_loss=0
            last_hidden_state=self.bart_model.base_model.encoder(input_ids.cuda(),attn_mask.cuda(),
                                                                 output_attentions=False,
                                                                 output_hidden_states=False).last_hidden_state
        input_for_classfn,l_p1,l_p2,classfn_out,classfn_loss=None,0,0,None,0
        if use_classfn or use_p1 or use_p2:
            input_for_classfn=torch.cdist(last_hidden_state.view(batch_size,-1),
                                          self.prototypes.view(self.num_protos,-1))
        if use_p1:
            l_p1=torch.mean(torch.min(input_for_classfn,dim=0)[0])
        if use_p2:            
            l_p2=torch.mean(torch.min(input_for_classfn,dim=1)[0])
        if use_classfn:
            classfn_out=self.classfn_model(input_for_classfn).view(batch_size,2)
            classfn_loss=self.loss_fn(classfn_out,y.cuda())
        if not use_rc:
            rc_loss=0
        total_loss=classfn_loss+rc_loss_lamb*rc_loss+p1_lamb*l_p1+p2_lamb*l_p2
        # return classfn_out,total_loss 
        return classfn_out, (total_loss, classfn_loss.detach().cpu(), rc_loss, l_p1,
                             l_p2)  

In [52]:
torch.cuda.empty_cache()        
model=SimpleProtoTex().cuda()
torch.cuda.empty_cache()

modelname="apr_24_simpleprotobart_onlyclass_80_20_train_nomask_protos_yesmask_enc_on"
save_path="../Models/"+modelname
logs_path="../Logs/"+modelname

In [None]:
"""
SimpleProtoTEx
"""
from transformers.optimization import AdamW
# optim=torch.optim.Adam(model.parameters(),lr=5e-5,weight_decay=0.01)
optim=AdamW(model.parameters(),lr=3e-5,weight_decay=0.01,eps=1e-8)
f=open(logs_path,"w")
f.writelines([""])
f.close()
epoch=-1
val_loss,mac_val_prec,mac_val_rec,mac_val_f1,mic_val_prec,mic_val_rec,mic_val_f1=evaluate(val_dl,model)
print_logs(logs_path,"VAL SCORES",epoch,val_loss,mac_val_prec,mac_val_rec,mac_val_f1,mic_val_prec,mic_val_rec,mic_val_f1)
val_loss,mac_val_prec,mac_val_rec,mac_val_f1,mic_val_prec,mic_val_rec,mic_val_f1=evaluate(train_dl,model)
print_logs(logs_path,"TRAIN SCORES",epoch,val_loss,mac_val_prec,mac_val_rec,mac_val_f1,mic_val_prec,mic_val_rec,mic_val_f1)
es=EarlyStopping(-np.inf,patience=7,path=save_path,save_epochwise=False)
n_iters=500
for epoch in range(n_iters):
    total_loss=0
    model.train()
    model.set_encoder_status(status=True)
    model.set_decoder_status(status=False)
    model.set_protos_status(status=True)
    model.set_classfn_status(status=True)
    classfn_loss,rc_loss,l_p1,l_p2,l_p3=[0]*5
    train_loader = tqdm(train_dl, total=len(train_dl), unit="batches",desc="training")
    for batch in train_loader:
        input_ids,attn_mask,y=batch
        classfn_out,loss=model(input_ids,attn_mask,y,use_decoder=0,use_classfn=1,
                               use_rc=0,use_p1=1,use_p2=1,rc_loss_lamb=1.0,p1_lamb=1.0,
                               p2_lamb=1.0)
        optim.zero_grad()
        loss[0].backward()
        optim.step()
        classfn_out=None
        loss=None
    total_loss=total_loss/len(train_dataset)
    val_loss,mac_val_prec,mac_val_rec,mac_val_f1,mic_val_prec,mic_val_rec,mic_val_f1=evaluate(train_dl,model)
    print_logs(logs_path,"TRAIN SCORES",epoch,val_loss,mac_val_prec,mac_val_rec,mac_val_f1,mic_val_prec,mic_val_rec,mic_val_f1)
    es.activate(mac_val_f1[0],mac_val_f1[1])
    val_loss,mac_val_prec,mac_val_rec,mac_val_f1,mic_val_prec,mic_val_rec,mic_val_f1=evaluate(val_dl,model)
    print_logs(logs_path,"VAL SCORES",epoch,val_loss,mac_val_prec,mac_val_rec,mac_val_f1,mic_val_prec,mic_val_rec,mic_val_f1)
    es((mac_val_f1[1]+mac_val_f1[0])/2,epoch,model)
    if es.early_stop:
        break
    if es.improved:
        """
        Below using "val_" prefix but the dl is that of test.
        """
        val_loss,mac_val_prec,mac_val_rec,mac_val_f1,mic_val_prec,mic_val_rec,mic_val_f1=evaluate(test_dl,model)
        print_logs(logs_path,"TEST SCORES",epoch,val_loss,mac_val_prec,mac_val_rec,mac_val_f1,mic_val_prec,mic_val_rec,mic_val_f1)
    elif (epoch+1)%5==0:
        """
        Below using "val_" prefix but the dl is that of test.
        """
        val_loss,mac_val_prec,mac_val_rec,mac_val_f1,mic_val_prec,mic_val_rec,mic_val_f1=evaluate(test_dl,model)
        print_logs(logs_path,"TEST SCORES (not the best ones)",epoch,val_loss,mac_val_prec,mac_val_rec,mac_val_f1,mic_val_prec,mic_val_rec,mic_val_f1)





  0%|          | 0/16 [00:00<?, ?batches/s]

VAL SCORES epoch -1 Total loss 57.4708 

VAL SCORES epoch -1 Prec [0.66666667 0.        ] 

VAL SCORES epoch -1 Recall [1. 0.] 

VAL SCORES epoch -1 F1 [0.8 0. ] 




  _warn_prf(average, modifier, msg_start, len(result))


  0%|          | 0/1021 [00:00<?, ?batches/s]

TRAIN SCORES epoch -1 Total loss 86.2612 

TRAIN SCORES epoch -1 Prec [0.5 0. ] 

TRAIN SCORES epoch -1 Recall [1. 0.] 

TRAIN SCORES epoch -1 F1 [0.66666667 0.        ] 




  _warn_prf(average, modifier, msg_start, len(result))


training:   0%|          | 0/1021 [00:00<?, ?batches/s]

  0%|          | 0/1021 [00:00<?, ?batches/s]

TRAIN SCORES epoch 0 Total loss 0.6963 

TRAIN SCORES epoch 0 Prec [0.50823896 0.58589815] 

TRAIN SCORES epoch 0 Recall [0.92751494 0.10255657] 

TRAIN SCORES epoch 0 F1 [0.65665742 0.17455819] 




  0%|          | 0/16 [00:00<?, ?batches/s]

VAL SCORES epoch 0 Total loss 0.6770 

VAL SCORES epoch 0 Prec [0.67310878 0.38862559] 

VAL SCORES epoch 0 Recall [0.90430267 0.12166172] 

VAL SCORES epoch 0 F1 [0.77176322 0.18531073] 


[92m Validation score improved (-inf --> 0.4785). [0m


  0%|          | 0/32 [00:00<?, ?batches/s]

TEST SCORES epoch 0 Total loss 0.6672 

TEST SCORES epoch 0 Prec [0.74040067 0.31619537] 

TEST SCORES epoch 0 Recall [0.90912197 0.11647727] 

TEST SCORES epoch 0 F1 [0.8161325  0.17024221] 




training:   0%|          | 0/1021 [00:00<?, ?batches/s]

  0%|          | 0/1021 [00:00<?, ?batches/s]

TRAIN SCORES epoch 1 Total loss 0.6959 

TRAIN SCORES epoch 1 Prec [0.  0.5] 

TRAIN SCORES epoch 1 Recall [0. 1.] 

TRAIN SCORES epoch 1 F1 [0.         0.66666667] 




  _warn_prf(average, modifier, msg_start, len(result))


  0%|          | 0/16 [00:00<?, ?batches/s]

VAL SCORES epoch 1 Total loss 0.7143 

VAL SCORES epoch 1 Prec [0.         0.33333333] 

VAL SCORES epoch 1 Recall [0. 1.] 

VAL SCORES epoch 1 F1 [0.  0.5] 


[93m EarlyStopping counter: 1 out of 7 [0m


  _warn_prf(average, modifier, msg_start, len(result))


training:   0%|          | 0/1021 [00:00<?, ?batches/s]

  0%|          | 0/1021 [00:00<?, ?batches/s]

TRAIN SCORES epoch 2 Total loss 0.7029 

TRAIN SCORES epoch 2 Prec [0.5 0. ] 

TRAIN SCORES epoch 2 Recall [1. 0.] 

TRAIN SCORES epoch 2 F1 [0.66666667 0.        ] 




  _warn_prf(average, modifier, msg_start, len(result))


  0%|          | 0/16 [00:00<?, ?batches/s]

VAL SCORES epoch 2 Total loss 0.6560 

VAL SCORES epoch 2 Prec [0.66666667 0.        ] 

VAL SCORES epoch 2 Recall [1. 0.] 

VAL SCORES epoch 2 F1 [0.8 0. ] 


[93m EarlyStopping counter: 2 out of 7 [0m


  _warn_prf(average, modifier, msg_start, len(result))


training:   0%|          | 0/1021 [00:00<?, ?batches/s]

  0%|          | 0/1021 [00:00<?, ?batches/s]

TRAIN SCORES epoch 3 Total loss 0.7005 

TRAIN SCORES epoch 3 Prec [0.49936168 0.25      ] 

TRAIN SCORES epoch 3 Recall [0.99617984 0.00127339] 

TRAIN SCORES epoch 3 F1 [0.66524939 0.00253387] 




  0%|          | 0/16 [00:00<?, ?batches/s]

VAL SCORES epoch 3 Total loss 0.6592 

VAL SCORES epoch 3 Prec [0.66666667 0.        ] 

VAL SCORES epoch 3 Recall [1. 0.] 

VAL SCORES epoch 3 F1 [0.8 0. ] 


[93m EarlyStopping counter: 3 out of 7 [0m


  _warn_prf(average, modifier, msg_start, len(result))


training:   0%|          | 0/1021 [00:00<?, ?batches/s]

  0%|          | 0/1021 [00:00<?, ?batches/s]

TRAIN SCORES epoch 4 Total loss 0.6979 

TRAIN SCORES epoch 4 Prec [0.48309889 0.46059716] 

TRAIN SCORES epoch 4 Recall [0.67616809 0.27652072] 

TRAIN SCORES epoch 4 F1 [0.56355621 0.34557473] 




  0%|          | 0/16 [00:00<?, ?batches/s]

VAL SCORES epoch 4 Total loss 0.7008 

VAL SCORES epoch 4 Prec [0.64687741 0.29793103] 

VAL SCORES epoch 4 Recall [0.62240356 0.32047478] 

VAL SCORES epoch 4 F1 [0.63440454 0.30879199] 


[93m EarlyStopping counter: 4 out of 7 [0m


  0%|          | 0/32 [00:00<?, ?batches/s]

TEST SCORES (not the best ones) epoch 4 Total loss 0.6941 

TEST SCORES (not the best ones) epoch 4 Prec [0.72050147 0.23446105] 

TEST SCORES (not the best ones) epoch 4 Recall [0.66757772 0.28219697] 

TEST SCORES (not the best ones) epoch 4 F1 [0.69303068 0.25612376] 




training:   0%|          | 0/1021 [00:00<?, ?batches/s]

  0%|          | 0/1021 [00:00<?, ?batches/s]

TRAIN SCORES epoch 5 Total loss 0.7161 

TRAIN SCORES epoch 5 Prec [0.47623369 0.34971306] 

TRAIN SCORES epoch 5 Recall [0.8224116  0.09550397] 

TRAIN SCORES epoch 5 F1 [0.60318259 0.15003462] 




  0%|          | 0/16 [00:00<?, ?batches/s]

VAL SCORES epoch 5 Total loss 0.6713 

VAL SCORES epoch 5 Prec [0.63442136 0.17210682] 

VAL SCORES epoch 5 Recall [0.79302671 0.08605341] 

VAL SCORES epoch 5 F1 [0.70491263 0.11473788] 


[93m EarlyStopping counter: 5 out of 7 [0m


training:   0%|          | 0/1021 [00:00<?, ?batches/s]

  0%|          | 0/1021 [00:00<?, ?batches/s]

TRAIN SCORES epoch 6 Total loss 0.7517 

TRAIN SCORES epoch 6 Prec [0.47734975 0.35224761] 

TRAIN SCORES epoch 6 Recall [0.827799   0.09364286] 

TRAIN SCORES epoch 6 F1 [0.60552431 0.14795326] 




  0%|          | 0/16 [00:00<?, ?batches/s]

VAL SCORES epoch 6 Total loss 0.6716 

VAL SCORES epoch 6 Prec [0.63636364 0.16339869] 

VAL SCORES epoch 6 Recall [0.81008902 0.07418398] 

VAL SCORES epoch 6 F1 [0.71279373 0.10204082] 


[93m EarlyStopping counter: 6 out of 7 [0m


training:   0%|          | 0/1021 [00:00<?, ?batches/s]

  0%|          | 0/1021 [00:00<?, ?batches/s]

TRAIN SCORES epoch 7 Total loss 0.6900 

TRAIN SCORES epoch 7 Prec [0.61437411 0.62120448] 

TRAIN SCORES epoch 7 Recall [0.63218729 0.60319326] 

TRAIN SCORES epoch 7 F1 [0.62315342 0.61206639] 




  0%|          | 0/16 [00:00<?, ?batches/s]

VAL SCORES epoch 7 Total loss 0.7222 

VAL SCORES epoch 7 Prec [0.73802817 0.41274817] 

VAL SCORES epoch 7 Recall [0.58308605 0.58605341] 

VAL SCORES epoch 7 F1 [0.6514712  0.48436542] 


[92m Validation score improved (0.4785 --> 0.5679). [0m


  0%|          | 0/32 [00:00<?, ?batches/s]

TEST SCORES epoch 7 Total loss 0.7038 

TEST SCORES epoch 7 Prec [0.78436874 0.34811828] 

TEST SCORES epoch 7 Recall [0.66860266 0.4905303 ] 

TEST SCORES epoch 7 F1 [0.72187385 0.4072327 ] 




training:   0%|          | 0/1021 [00:00<?, ?batches/s]

  0%|          | 0/1021 [00:00<?, ?batches/s]

TRAIN SCORES epoch 8 Total loss 0.7215 

TRAIN SCORES epoch 8 Prec [0.48420803 0.42557359] 

TRAIN SCORES epoch 8 Recall [0.79890293 0.14898619] 

TRAIN SCORES epoch 8 F1 [0.60296455 0.22070667] 




  0%|          | 0/16 [00:00<?, ?batches/s]

VAL SCORES epoch 8 Total loss 0.6934 

VAL SCORES epoch 8 Prec [0.63718958 0.20215633] 

VAL SCORES epoch 8 Recall [0.78041543 0.11127596] 

VAL SCORES epoch 8 F1 [0.70156719 0.14354067] 


[93m EarlyStopping counter: 1 out of 7 [0m


training:   0%|          | 0/1021 [00:00<?, ?batches/s]

  0%|          | 0/1021 [00:00<?, ?batches/s]

TRAIN SCORES epoch 9 Total loss 0.7232 

TRAIN SCORES epoch 9 Prec [0.49118943 0.46908127] 

TRAIN SCORES epoch 9 Recall [0.7645215  0.20805172] 

TRAIN SCORES epoch 9 F1 [0.59810721 0.28825405] 




  0%|          | 0/16 [00:00<?, ?batches/s]

VAL SCORES epoch 9 Total loss 0.7160 

VAL SCORES epoch 9 Prec [0.63653724 0.22838137] 

VAL SCORES epoch 9 Recall [0.74183976 0.15281899] 

VAL SCORES epoch 9 F1 [0.68516615 0.18311111] 


[93m EarlyStopping counter: 2 out of 7 [0m


  0%|          | 0/32 [00:00<?, ?batches/s]

TEST SCORES (not the best ones) epoch 9 Total loss 0.6724 

TEST SCORES (not the best ones) epoch 9 Prec [0.71732234 0.19855596] 

TEST SCORES (not the best ones) epoch 9 Recall [0.77246327 0.15625   ] 

TEST SCORES (not the best ones) epoch 9 F1 [0.74387235 0.17488076] 




training:   0%|          | 0/1021 [00:00<?, ?batches/s]

  0%|          | 0/1021 [00:00<?, ?batches/s]

TRAIN SCORES epoch 10 Total loss 0.7094 

TRAIN SCORES epoch 10 Prec [0.50364378 0.50997986] 

TRAIN SCORES epoch 10 Recall [0.73787834 0.27279851] 

TRAIN SCORES epoch 10 F1 [0.59866487 0.35545629] 




  0%|          | 0/16 [00:00<?, ?batches/s]

VAL SCORES epoch 10 Total loss 0.7344 

VAL SCORES epoch 10 Prec [0.66761769 0.33548387] 

VAL SCORES epoch 10 Recall [0.69436202 0.30860534] 

VAL SCORES epoch 10 F1 [0.68072727 0.32148377] 


[93m EarlyStopping counter: 3 out of 7 [0m


training:   0%|          | 0/1021 [00:00<?, ?batches/s]

  0%|          | 0/1021 [00:00<?, ?batches/s]

TRAIN SCORES epoch 11 Total loss 0.7021 

TRAIN SCORES epoch 11 Prec [0.62357797 0.61707201] 

TRAIN SCORES epoch 11 Recall [0.60671956 0.63375453] 

TRAIN SCORES epoch 11 F1 [0.61503326 0.62530202] 




  0%|          | 0/16 [00:00<?, ?batches/s]

VAL SCORES epoch 11 Total loss 0.7631 

VAL SCORES epoch 11 Prec [0.73382046 0.39379699] 

VAL SCORES epoch 11 Recall [0.52151335 0.62166172] 

VAL SCORES epoch 11 F1 [0.60971379 0.48216341] 


[93m EarlyStopping counter: 4 out of 7 [0m


training:   0%|          | 0/1021 [00:00<?, ?batches/s]

  0%|          | 0/1021 [00:00<?, ?batches/s]

TRAIN SCORES epoch 12 Total loss 0.6960 

TRAIN SCORES epoch 12 Prec [0.76124409 0.62363807] 

TRAIN SCORES epoch 12 Recall [0.48907826 0.84660594] 

TRAIN SCORES epoch 12 F1 [0.59553912 0.71821506] 




  0%|          | 0/16 [00:00<?, ?batches/s]

VAL SCORES epoch 12 Total loss 0.7238 

VAL SCORES epoch 12 Prec [0.70542636 0.38559814] 

VAL SCORES epoch 12 Recall [0.60756677 0.4925816 ] 

VAL SCORES epoch 12 F1 [0.65284974 0.43257329] 


[93m EarlyStopping counter: 5 out of 7 [0m


training:   0%|          | 0/1021 [00:00<?, ?batches/s]

  0%|          | 0/1021 [00:00<?, ?batches/s]

TRAIN SCORES epoch 13 Total loss 0.6962 

TRAIN SCORES epoch 13 Prec [0.97916667 0.53119457] 

TRAIN SCORES epoch 13 Recall [0.11969831 0.99745323] 

TRAIN SCORES epoch 13 F1 [0.21331937 0.69321624] 




  0%|          | 0/16 [00:00<?, ?batches/s]

VAL SCORES epoch 13 Total loss 0.7801 

VAL SCORES epoch 13 Prec [0.84567901 0.36749117] 

VAL SCORES epoch 13 Recall [0.20326409 0.92581602] 

VAL SCORES epoch 13 F1 [0.3277512  0.52613828] 


[93m EarlyStopping counter: 6 out of 7 [0m


training:   0%|          | 0/1021 [00:00<?, ?batches/s]

  0%|          | 0/1021 [00:00<?, ?batches/s]

In [None]:
class ProtoTex(torch.nn.Module):
    def __init__(self,n_classes=2,bias=True,dropout=False,special_classfn=False,p=0.5,batchnormlp1=False):
        super().__init__()
        self.bart_model=BartForConditionalGeneration.from_pretrained('facebook/bart-large') 
        self.bart_out_dim=self.bart_model.config.d_model
        self.one_by_sqrt_bartoutdim=1/torch.sqrt(torch.tensor(self.bart_out_dim).float())
        self.max_position_embeddings=256
        self.num_protos=num_prototypes
        self.num_pos_protos=num_pos_protos
        self.num_neg_protos=self.num_protos-self.num_pos_protos
        self.pos_prototypes=torch.nn.Parameter(torch.rand(self.num_pos_protos,self.max_position_embeddings,self.bart_out_dim))
        self.neg_prototypes=torch.nn.Parameter(torch.rand(self.num_neg_protos,self.max_position_embeddings,self.bart_out_dim))
        self.classfn_model=torch.nn.Linear(self.num_protos,2,bias=bias)
        self.loss_fn=torch.nn.CrossEntropyLoss(reduction="mean")
        
        self.do_dropout=dropout
        self.special_classfn=special_classfn
        
        self.dropout=torch.nn.Dropout(p=p)
        self.dobatchnorm=batchnormlp1
        self.distance_grounder = torch.zeros(2, self.num_protos).cuda()
        self.distance_grounder[0][:self.num_pos_protos] = 1e7
        self.distance_grounder[1][self.num_pos_protos:] = 1e7

    
    def set_prototypes(self,do_random=False):
        if do_random:
            print("initializing prototypes with xavier init")
            torch.nn.init.xavier_normal_(self.pos_prototypes)
            torch.nn.init.xavier_normal_(self.neg_prototypes)
        else:
            print("initializing prototypes with encoded outputs")
            self.eval()
            with torch.no_grad():
                self.pos_prototypes=torch.nn.Parameter(
                    self.bart_model.base_model.encoder(input_ids_pos_rdm.cuda(),
                                                       attn_mask_pos_rdm.cuda(),
                                                       output_attentions=False,
                                                       output_hidden_states=False).last_hidden_state)
                self.neg_prototypes=torch.nn.Parameter(
                    self.bart_model.base_model.encoder(input_ids_neg_rdm.cuda(),
                                                       attn_mask_neg_rdm.cuda(),
                                                       output_attentions=False,
                                                       output_hidden_states=False).last_hidden_state)
    
    def set_shared_status(self,status=True):
        print("ALERT!!! Shared variable is shared by encoder_input_embeddings and decoder_input_embeddings")
        self.bart_model.model.shared.requires_grad_(status)

    def set_encoder_status(self,status=True):
        self.num_enc_layers=len(self.bart_model.base_model.encoder.layers)
        for i in range(self.num_enc_layers):
            self.bart_model.base_model.encoder.layers[i].requires_grad_(False)
        self.bart_model.base_model.encoder.layers[self.num_enc_layers-1].requires_grad_(status)
        return
    def set_decoder_status(self,status=True):
        self.num_dec_layers=len(self.bart_model.base_model.decoder.layers)
        for i in range(self.num_dec_layers):
            self.bart_model.base_model.decoder.layers[i].requires_grad_(False)
        self.bart_model.base_model.decoder.layers[self.num_dec_layers-1].requires_grad_(status)
        return
    def set_classfn_status(self,status=True):
        self.classfn_model.requires_grad_(status)

    def set_protos_status(self,pos_or_neg=None,status=True):
        if pos_or_neg=="pos" or pos_or_neg is None:
            self.pos_prototypes.requires_grad=status       
        if pos_or_neg=="neg" or pos_or_neg is None:
            self.neg_prototypes.requires_grad=status       
        

    def forward(self, input_ids, attn_mask, y, use_decoder=1, use_classfn=0, use_rc=0, use_p1=0, use_p2=0,
                use_p3=0, classfn_lamb=1.0, rc_loss_lamb=0.95, p1_lamb=0.93, p2_lamb=0.92, p3_lamb=1.0,
                distmask_lp1=0,distmask_lp2=0,
                pos_or_neg=None,random_mask_for_distanceMat=None):
        """
            1. p3_loss is the prototype-distance-maximising loss. See the set of lines after the line "if use_p3:"
            2. We also have flags distmask_lp1 and distmask_lp2 which uses "masked" distance matrix for calculating lp1 and lp2 loss.
            3. the flag "random_mask_for_distanceMat" is an experimental part. It randomly masks (artificially inflates) 
            random places in the distance matrix so as to encourage more prototypes be "discovered" by the training 
            examples.
        """
        batch_size = input_ids.size(0)
        if use_decoder:
            labels = input_ids.cuda() + 0
            labels[labels == self.bart_model.config.pad_token_id] = -100
            bart_output = self.bart_model(input_ids.cuda(), attn_mask.cuda(), labels=labels,
                                          output_attentions=False, output_hidden_states=False)
            rc_loss, last_hidden_state = bart_output.loss, bart_output.encoder_last_hidden_state
        else:
            rc_loss = torch.tensor(0)
            last_hidden_state = self.bart_model.base_model.encoder(input_ids.cuda(), attn_mask.cuda(),
                                                                   output_attentions=False,
                                                                   output_hidden_states=False).last_hidden_state
        input_for_classfn, l_p1, l_p2, l_p3, l_p4, classfn_out, classfn_loss = (None, torch.tensor(0), torch.tensor(0),
                                                                                torch.tensor(0), torch.tensor(0), None,
                                                                                torch.tensor(0))
        if use_classfn or use_p1 or use_p2 or use_p3:
            all_protos = torch.cat((self.pos_prototypes, self.neg_prototypes), dim=0)
            if use_classfn or use_p1 or use_p2:
                if not self.dobatchnorm:
                    input_for_classfn = self.one_by_sqrt_bartoutdim * torch.cdist(last_hidden_state.view(batch_size, -1),
                                                                                  all_protos.view(self.num_protos, -1))
                else:
                    input_for_classfn = torch.cdist(last_hidden_state.view(batch_size, -1),
                                                    all_protos.view(self.num_protos, -1))
                    input_for_classfn= torch.nn.functional.instance_norm(
                        input_for_classfn.view(batch_size,1,self.num_protos)).view(batch_size,
                                                                                   self.num_protos)
            if use_p1 or use_p2:
                distance_mask = self.distance_grounder[y.cuda()]
                input_for_classfn_masked = input_for_classfn+distance_mask
                if random_mask_for_distanceMat:
                    random_mask=torch.bernoulli(torch.ones_like(input_for_classfn_masked)*
                                                random_mask_for_distanceMat).bool()
                    input_for_classfn_masked[random_mask]=1e7
        if use_p1:
            l_p1 = torch.mean(torch.min(input_for_classfn_masked if distmask_lp1 else input_for_classfn, dim=0)[0])
        if use_p2:
            l_p2 = torch.mean(torch.min(input_for_classfn_masked if distmask_lp2 else input_for_classfn, dim=1)[0])
        if use_p3:
            l_p3 = self.one_by_sqrt_bartoutdim * torch.mean(torch.pdist(
                self.pos_prototypes.view(self.num_pos_protos,-1)))
        if use_classfn:
            if self.do_dropout:
                if self.special_classfn:
                    classfn_out = (input_for_classfn@self.classfn_model.weight.t()+
                                   self.dropout(self.classfn_model.bias.repeat(batch_size,1))).view(batch_size, 2)
                else:
                    classfn_out = self.classfn_model(self.dropout(input_for_classfn)).view(batch_size, 2)
            else:
                classfn_out = self.classfn_model(input_for_classfn).view(batch_size, 2)
            classfn_loss = self.loss_fn(classfn_out, y.cuda())
        if not use_rc:
            rc_loss = torch.tensor(0)
        total_loss = classfn_lamb * classfn_loss + rc_loss_lamb * rc_loss + p1_lamb * l_p1 + p2_lamb * l_p2 - p3_lamb * l_p3
        return classfn_out, (total_loss, classfn_loss.detach().cpu(), rc_loss.detach().cpu(), l_p1.detach().cpu(),
                             l_p2.detach().cpu(), l_p3.detach().cpu())   

In [None]:
torch.cuda.empty_cache()
modelname="NegProtoTEx_protos_xavier_large_bs20_20_woRat_noReco_g2d_nobias_nodrop_cu1_PosUp_normed"
model=ProtoTex(bias=False,dropout=False,special_classfn=False,p=0.75,batchnormlp1=True).cuda()
model.set_prototypes(do_random=True)
torch.cuda.empty_cache()

save_path="../Models/"+modelname
logs_path="../Logs/"+modelname

In [None]:
"""
ProtoTEx Training
"""
from transformers.optimization import AdamW
optim=AdamW(model.parameters(),lr=3e-5,weight_decay=0.01,eps=1e-8)
f=open(logs_path,"w")
f.writelines([""])
f.close()
val_loss,mac_val_prec,mac_val_rec,mac_val_f1,mic_val_prec,mic_val_rec,mic_val_f1=evaluate(val_dl,model)
epoch=-1
print_logs(logs_path,"VAL SCORES",epoch,val_loss,mac_val_prec,mac_val_rec,mac_val_f1,mic_val_prec,mic_val_rec,mic_val_f1)
es=EarlyStopping(-np.inf,patience=7,path=save_path,save_epochwise=False)
n_iters=1000
gamma=2
delta=1
kappa=1
p1_lamb=0.9
p2_lamb=0.9
p3_lamb=0.9
for iter_ in range(n_iters):
    total_loss = 0
    """
    During Delta, We want decoder to become better at decoding the trained encoder
    and Prototypes to become closer to some encoded representation. And that's why it makes 
    sense to use l_p1 loss and not l_p2 loss.
    losses- rc_loss, l_p1 loss
    trainable- decoder and prototypes
    details- makes pos_prototypes closer to pos_egs and neg_protos closer to neg_egs 
    """
    model.train()
    model.set_encoder_status(status=False)
    model.set_decoder_status(status=False)
    model.set_protos_status(status=True)
    model.set_classfn_status(status=False)
    model.set_shared_status(status=True)

    for epoch in range(delta):
        train_loader = tqdm(train_dl, total=len(train_dl), unit="batches", desc="delta training")
        for batch in train_loader:
            input_ids, attn_mask, y = batch
            classfn_out, loss = model(input_ids, attn_mask, y, use_decoder=0, use_classfn=0,
                                      use_rc=0, use_p1=1, use_p2=0, use_p3=0,
                                      rc_loss_lamb=1.0, p1_lamb=p1_lamb, p2_lamb=p2_lamb,
                                      p3_lamb=p3_lamb,distmask_lp1=1,distmask_lp2=1,
                                      random_mask_for_distanceMat=None)
            optim.zero_grad()
            loss[0].backward()
            optim.step()
    """
    During gamma, we only want to improve the classification performance. Therefore we will
    improve encoder to become closer to the prototypes, at the same time also improving
    the classification accuracy. That's why encoder and classification layer must be trainabl
    together without segrregating pos and neg examples.
    Only Encoder and Classfn are trainable
    """
    model.train()
    model.set_encoder_status(status=True)
    model.set_decoder_status(status=False)
    model.set_protos_status(status=False)
    model.set_classfn_status(status=True)
    model.set_shared_status(status=True)

    for epoch in range(gamma):
        train_loader = tqdm(train_dl, total=len(train_dl), unit="batches", desc="gamma training")
        for batch in train_loader:
            input_ids, attn_mask, y = batch
            classfn_out, loss = model(input_ids, attn_mask, y, use_decoder=0, use_classfn=1,
                                      use_rc=0, use_p1=0, use_p2=1,
                                      rc_loss_lamb=1., p1_lamb=p1_lamb,p2_lamb=p2_lamb,
                                      distmask_lp1 = 1, distmask_lp2 = 1)
            optim.zero_grad()
            loss[0].backward()
            optim.step()

    val_loss,mac_val_prec,mac_val_rec,mac_val_f1,mic_val_prec,mic_val_rec,mic_val_f1=evaluate(train_dl,model)
    print_logs(logs_path,"TRAIN SCORES",iter_,val_loss,mac_val_prec,mac_val_rec,mac_val_f1,mic_val_prec,mic_val_rec,mic_val_f1)
    es.activate(mac_val_f1[0],mac_val_f1[1])

    val_loss,mac_val_prec,mac_val_rec,mac_val_f1,mic_val_prec,mic_val_rec,mic_val_f1=evaluate(val_dl,model)
    print_logs(logs_path,"VAL SCORES",iter_,val_loss,mac_val_prec,mac_val_rec,mac_val_f1,mic_val_prec,mic_val_rec,mic_val_f1)        
    es(0.5*(mac_val_f1[1]+mac_val_f1[0]),epoch,model)
    if es.early_stop:
        break
    if es.improved:
        """
        Below using "val_" prefix but the dl is that of test.
        """
        val_loss,mac_val_prec,mac_val_rec,mac_val_f1,mic_val_prec,mic_val_rec,mic_val_f1=evaluate(test_dl,model)
        print_logs(logs_path,"TEST SCORES",iter_,val_loss,mac_val_prec,mac_val_rec,mac_val_f1,mic_val_prec,mic_val_rec,mic_val_f1)
    elif (iter_+1)%5==0:
        """
        Below using "val_" prefix but the dl is that of test.
        """
        val_loss,mac_val_prec,mac_val_rec,mac_val_f1,mic_val_prec,mic_val_rec,mic_val_f1=evaluate(test_dl,model)
        print_logs(logs_path,"TEST SCORES (not the best ones)",iter_,val_loss,mac_val_prec,mac_val_rec,mac_val_f1,mic_val_prec,mic_val_rec,mic_val_f1)