In [None]:
import sys

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
import torchviz

# Project-local star imports: order is kept as-is because a later module can shadow
# names exported by an earlier one.
from model import *
from utils import *
from dataset_loader import *
from config import * 


In [None]:
# Seed every RNG used below so that "Restart Kernel -> Run All" reproduces the same numbers.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Print tensors in full instead of eliding the middle with "...".
torch.set_printoptions(threshold=sys.maxsize)

In [None]:
USE_GPU = False             # flip to True to opt in to an accelerator when one is present
dtype = torch.float32       # floating-point type used for the tensors created below

# Start from the CPU and upgrade only when the user asked for a GPU *and* a backend exists.
# CUDA (Nvidia) is preferred over MPS (Metal, macOS >= 12.3).
device = torch.device('cpu')
if USE_GPU:
    if torch.cuda.is_available():
        device = torch.device('cuda')
    elif torch.backends.mps.is_available():
        device = torch.device("mps")

print('using device:', device)


In [None]:
# ---- logging ----
print_every = 100   # how frequently the train loss is printed

# ---- model size ----
N = 2               # number of repeated ISML blocks, must be >= 1
D = 75              # maximum document length; inputs are padded to this size
hidden_size = 100   # hidden units per LSTM direction
# emb_length = 200

# ---- batching / pairing window ----
batch_size = 8
W = 3               # half-width of the diagonal band kept by the sliding mask

# ---- loss weights: lamb_1 -> ISML term, lamb_2 -> CML term, lamb_3 -> EML term ----
lamb_1 = lamb_2 = lamb_3 = 1

# ---- optimisation settings (not referenced by the cells visible here) ----
lr = 5e-3
l2 = 1e-5
dropout = 0.5

class SecondHalfNet(nn.Module):
    """Stack of N emotion/cause BiLSTM blocks followed by two pair-scoring heads.

    Block ``n`` runs one bidirectional LSTM for emotions and one for causes over
    the current sequence representation, maps each LSTM output to a 2-way
    softmax, and concatenates both predictions onto the representation that is
    fed to block ``n + 1`` (so the feature size grows by 4 per block).

    Args:
        N: number of stacked blocks; must be >= 1.
        D: width of the two pair-scoring heads (``fc_cml`` / ``fc_eml``).
        hidden_size: hidden units per LSTM direction; the input is expected to
            have ``hidden_size * 2`` features.
        device: accepted for backward compatibility; not used by the module.
            Move the module with ``.to(device)`` instead.

    Raises:
        ValueError: if ``N < 1``.
    """

    def __init__(self,N,D,hidden_size,device):
        super().__init__()
        if N < 1:
            # forward() needs the LSTM outputs of the last block for the scoring heads.
            raise ValueError(f"N must be >= 1, got {N}")
        self.N = N

        # nn.ModuleList (not a plain list) so the layers are registered: otherwise
        # .parameters(), .to(device) and state_dict() silently skip them.
        self.bilstm_e_list = nn.ModuleList()
        self.bilstm_c_list = nn.ModuleList()
        self.fc_e_list = nn.ModuleList()
        self.fc_c_list = nn.ModuleList()
        for n in range(N):
            # Each earlier block appended 2 emotion + 2 cause probabilities.
            in_features = hidden_size*2 + 4*n

            self.bilstm_e_list.append(nn.LSTM(input_size=in_features, hidden_size=hidden_size,
                                              num_layers=1, batch_first=True, bidirectional=True))
            self.bilstm_c_list.append(nn.LSTM(input_size=in_features, hidden_size=hidden_size,
                                              num_layers=1, batch_first=True, bidirectional=True))
            self.fc_e_list.append(nn.Linear(hidden_size*2, 2))
            self.fc_c_list.append(nn.Linear(hidden_size*2, 2))

        self.fc_cml = nn.Linear(hidden_size*2, D)
        self.fc_eml = nn.Linear(hidden_size*2, D)

    def forward(self, s1):
        """Run all blocks on ``s1`` (batch-first, ``hidden_size * 2`` features).

        Returns:
            A 5-tuple ``(y_e_list, y_c_list, s_final, cml_scores, eml_scores)``:
            the per-block emotion and cause softmax outputs, the input with every
            block's predictions concatenated on the last dim, and the raw scores
            of ``fc_cml`` (fed by the last emotion LSTM output) and ``fc_eml``
            (fed by the last cause LSTM output).
        """
        self.y_e_list = []
        self.y_c_list = []
        s_tmp = s1

        # Iterate over the instance's own layers rather than a module-level N, so the
        # number of blocks always matches what __init__ built.
        blocks = zip(self.bilstm_e_list, self.fc_e_list, self.bilstm_c_list, self.fc_c_list)
        for bilstm_e, fc_e, bilstm_c, fc_c in blocks:
            e_lstm_out,_ = bilstm_e(s_tmp)
            y_e = nn.functional.softmax(fc_e(e_lstm_out),dim=2)
            self.y_e_list.append(y_e)

            c_lstm_out,_ = bilstm_c(s_tmp)
            y_c = nn.functional.softmax(fc_c(c_lstm_out),dim=2)
            self.y_c_list.append(y_c)

            s_tmp = torch.cat((s_tmp,y_e,y_c),dim=2)

        cml_scores = self.fc_cml(e_lstm_out)
        eml_scores = self.fc_eml(c_lstm_out)

        return self.y_e_list,self.y_c_list,s_tmp,cml_scores,eml_scores




In [None]:
def test_SecondHalfNet():
    """Smoke-test SecondHalfNet: push one random batch through it and print output sizes."""
    # The random input is drawn before the model is built, so RNG consumption order
    # (and therefore the initial weights) stays the same from run to run.
    dummy_input = torch.rand(
        (batch_size, D, hidden_size*2),
        dtype=dtype,
        device=device,
        requires_grad=False,
    )

    net = SecondHalfNet(N,D,hidden_size,device).to(device=device)
    net = torch.compile(net)    # Pytorch 2.0 acceleration
    net.train()

    emo_preds, cause_preds, _s_final, _cml_scores, _eml_scores = net(dummy_input)
    print(len(emo_preds),len(cause_preds),emo_preds[0].shape,cause_preds[0].shape)


test_SecondHalfNet()

In [None]:
def input_padding(s1,len_target=None):
    """Zero-pad ``s1`` at the end of dim 1 so that it has ``len_target`` entries.

    Args:
        s1: tensor whose second-to-last dimension is dim 1 (e.g. a 3-D batch).
        len_target: desired size of dim 1. Defaults to the module-level ``D``
            (the maximum document length), looked up at call time.

    Returns:
        The padded tensor; the last dimension is left untouched.

    Note:
        If ``s1`` is already longer than ``len_target`` the (negative) difference is
        handed to ``torch.nn.functional.pad`` unchanged.
    """
    if len_target is None:
        len_target = D
    # Pad spec is (last-dim left, last-dim right, dim-1 front, dim-1 back).
    missing = len_target - s1.shape[1]
    return torch.nn.functional.pad(s1,(0,0,0,missing),value=0)

In [None]:
def labelTransform(doc_couples_b):
    """Turn per-document (emotion, cause) couples into dense label tensors.

    Args:
        doc_couples_b: one entry per document in the batch, each an iterable of
            ``(emo, cau)`` pairs of 1-based positions.

    Returns:
        ``(y_e_isml, y_c_isml, y_cml_pairs, y_eml_pairs)`` on the module-level
        ``device``:
          * ``y_e_isml`` / ``y_c_isml``: shape ``(batch, D, 2)``; column 0 is set to 1
            at every emotion / cause position (the True prob is col 0 and the False
            prob is col 1).
          * ``y_cml_pairs``: shape ``(batch, D, D)``, 1 at ``[doc, emo-1, cau-1]``.
          * ``y_eml_pairs``: shape ``(batch, D, D)``, 1 at ``[doc, cau-1, emo-1]``.

    The batch dimension follows ``len(doc_couples_b)``, so a short final batch no
    longer raises ``IndexError``.
    """
    n_docs = len(doc_couples_b)

    # Filled on the CPU and moved once at the end.
    y_e_isml = torch.zeros(n_docs, D, 2)
    y_c_isml = torch.zeros(n_docs, D, 2)
    y_cml_pairs = torch.zeros(n_docs, D, D)
    y_eml_pairs = torch.zeros(n_docs, D, D)

    # NOTE(review): column 1 is never set, so positions without a label carry an
    # all-zero target - confirm this is intended for the cross-entropy in loss_calc.
    for i, couples in enumerate(doc_couples_b):
        for emo, cau in couples:
            y_e_isml[i, emo-1, 0] = 1
            y_c_isml[i, cau-1, 0] = 1
            y_cml_pairs[i, emo-1, cau-1] = 1
            y_eml_pairs[i, cau-1, emo-1] = 1

    # Tensor.to() returns a new tensor; the result must be kept or the move is lost.
    y_e_isml = y_e_isml.to(device=device)
    y_c_isml = y_c_isml.to(device=device)
    y_cml_pairs = y_cml_pairs.to(device=device)
    y_eml_pairs = y_eml_pairs.to(device=device)

    return y_e_isml,y_c_isml,y_cml_pairs,y_eml_pairs

In [None]:
def slidingmask_gen():
    """Build the banded pair mask used by the loss.

    Returns:
        A ``(batch_size, D, D)`` tensor on the module-level ``device`` that is 1 where
        ``|row - col| <= W`` and 0 elsewhere; every batch entry is the same band.
    """
    band = torch.ones(D,D)
    band = torch.triu(band,diagonal=-W)   # zero everything more than W below the diagonal
    band = torch.tril(band,diagonal=W)    # zero everything more than W above the diagonal
    slidingmask = band.repeat(batch_size,1,1)
    # Tensor.to() is not in-place: return its result, otherwise the mask stays on the CPU.
    return slidingmask.to(device=device)

# print(slidingmask_gen().shape)

In [None]:
def loss_calc(y_e_list,y_c_list,doc_couples_b,cml_scores,eml_scores,slidingmask):
    """Compute the weighted training loss and the masked pair probabilities.

    Args:
        y_e_list, y_c_list: per-block emotion / cause probabilities (softmax outputs).
        doc_couples_b: ground-truth couples of the batch, passed to ``labelTransform``.
        cml_scores, eml_scores: raw pair scores.
        slidingmask: 0/1 mask applied element-wise to both pair terms.

    Returns:
        ``(loss_total, cml_out, eml_out)`` where
        ``loss_total = lamb_1 * loss_isml + lamb_2 * loss_cmll + lamb_3 * loss_emll``
        and ``cml_out`` / ``eml_out`` are the masked pair probabilities, computed
        without gradient tracking.

    The pair probability is ``1 / (1 + exp(score))``, i.e. ``sigmoid(-score)``. That
    mapping is kept, but evaluated through ``logsigmoid`` / ``sigmoid`` so that large
    scores no longer overflow into ``0 * -inf = NaN``.
    """
    score_device = cml_scores.device
    # Keep labels and mask on the same device as the scores, whatever the helpers return.
    y_e_isml,y_c_isml,y_cml_pairs,y_eml_pairs = (
        label.to(score_device) for label in labelTransform(doc_couples_b)
    )
    # All mask rows are identical, so trimming to the actual batch size is safe.
    slidingmask = slidingmask[:cml_scores.shape[0]].to(score_device)

    def _safe_log(prob):
        # A probability of exactly 0 would give -inf, and 0 * -inf is NaN.
        return torch.log(prob.clamp_min(torch.finfo(prob.dtype).tiny))

    def _masked_pair_loss(scores, targets):
        log_p = nn.functional.logsigmoid(-scores)      # log(1 / (1 + exp(s)))
        log_not_p = nn.functional.logsigmoid(scores)   # log(1 - 1 / (1 + exp(s)))
        per_pair = targets * log_p + (1 - targets) * log_not_p
        return -torch.sum(slidingmask * per_pair)

    # Walk the supplied prediction lists instead of a module-level block count.
    loss_isml = 0
    for y_e, y_c in zip(y_e_list, y_c_list):
        loss_isml = loss_isml \
                    - torch.sum(y_e_isml * _safe_log(y_e)) \
                    - torch.sum(y_c_isml * _safe_log(y_c))

    loss_cmll = _masked_pair_loss(cml_scores, y_cml_pairs)
    loss_emll = _masked_pair_loss(eml_scores, y_eml_pairs)

    with torch.no_grad():
        cml_out = slidingmask * torch.sigmoid(-cml_scores)
        eml_out = slidingmask * torch.sigmoid(-eml_scores)

    loss_total = lamb_1 * loss_isml + lamb_2 * loss_cmll + lamb_3 * loss_emll

    return loss_total,cml_out,eml_out

In [None]:
def inference(cml_out,eml_out,mode='avg'):
    """Decode predicted pairs from the two masked probability maps.

    ``cml_out`` is indexed ``[batch, emo, cause]`` while ``eml_out`` is indexed
    ``[batch, cause, emo]``, so ``eml_out`` is transposed before the two are combined
    in every mode.

    Args:
        cml_out, eml_out: ``(batch, D, D)`` probability maps.
        mode: ``'avg'`` (mean of both maps > 0.5), ``'logic_and'`` (both maps > 0.5)
            or ``'logic_or'`` (either map > 0.5).

    Returns:
        Index tensor with one row ``[batch, emo_clause, cause_clause]`` per predicted pair.

    Raises:
        ValueError: if ``mode`` is not one of the three supported values.
    """
    eml_out_T = torch.permute(eml_out, (0,2,1))

    if mode == 'avg':
        out = ((cml_out + eml_out_T)/2) > 0.5
    elif mode == 'logic_and':
        out = torch.logical_and(cml_out > 0.5, eml_out_T > 0.5)
    elif mode == 'logic_or':
        out = torch.logical_or(cml_out > 0.5, eml_out_T > 0.5)
    else:
        raise ValueError(f"Unknown inference mode: {mode!r} "
                         "(expected 'avg', 'logic_and' or 'logic_or')")

    return out.nonzero()  # output index pairs: [batch,emo_clause,cause_clause]

    

In [None]:
def load_data(configs):
    """Build the train / validation / test loaders for fold 1 of the configured split.

    Side effect: sets ``configs.epochs`` (20 for ``'split10'``, 15 for ``'split20'``).

    Args:
        configs: configuration object with a ``split`` attribute.

    Returns:
        ``(train_loader, val_loader, test_loader)``. A validation loader is only built
        for ``'split20'``; for ``'split10'`` ``val_loader`` is ``None``.

    Raises:
        ValueError: if ``configs.split`` is neither ``'split10'`` nor ``'split20'``.
    """
    if configs.split == 'split10':
        n_folds = 10
        configs.epochs = 20
    elif configs.split == 'split20':
        n_folds = 20
        configs.epochs = 15
    else:
        # Raise instead of exit(): exit() would tear down the notebook kernel.
        raise ValueError(f'Unknown data split: {configs.split!r}')

    fold_id = 1
    train_loader = build_train_data(configs, fold_id=fold_id)

    # Always bind val_loader so the return below cannot hit an unbound local.
    val_loader = None
    if configs.split == 'split20':
        val_loader = build_inference_data(configs, fold_id=fold_id, data_type='valid')

    test_loader = build_inference_data(configs, fold_id=fold_id, data_type='test')
    return train_loader, val_loader, test_loader

In [None]:
def test_wholesys():
    """End-to-end smoke test on a single training batch.

    Pipeline: load one batch -> encoder (``Network``) -> pad to ``D`` -> ``ISMLBlock``
    -> ``loss_calc`` -> ``inference``. Prints the total loss and the decoded pair indices.
    """
    sliding_mask = slidingmask_gen()

    configs = Config()
    train_set, val_set, test_set = load_data(configs)
    instance = next(iter(train_set))

    # doc_len_b: document length in a batch
    # adj_b: adj matrix in a batch, do not need this, #todo: will remove this
    # y_emotions_b: binary vector indicating emotion clause in a batch, -1 means no sentences in this document
    # y_causes_b: binary vector indicating cause clause in a batch, -1 means no sentences in this document
    # y_mask_b: binary vector indicating whether a sentence is valid in a batch, -1 means no sentences in this document
    # doc_couples_b: ground truth label in a batch
    # doc_id_b: document id in a batch
    # bert_token_b: input ids in a batch
    # bert_segment_b: segment ids in a batch
    # bert_masks_b: attention masks in a batch
    # bert_clause_b: [CLS] index for each doc in a batch
    doc_len_b, adj_b, y_emotions_b, y_causes_b, y_mask_b, doc_couples_b, doc_id_b, \
    bert_token_b, bert_segment_b, bert_masks_b, bert_clause_b = instance

    prev_model = Network()
    prev_model = prev_model.to(device=device)
    prev_model = torch.compile(prev_model)    # Pytorch 2.0 acceleration
    prev_model.train()
    s1 = prev_model(bert_token_b, bert_segment_b, bert_masks_b, bert_clause_b)
    # Tensor.to() returns a new tensor, so the result has to be kept.
    s1 = input_padding(s1).to(device=device)

    # NOTE(review): ISMLBlock is not defined in the cells visible here - presumably it
    # comes from `from model import *`; confirm it matches the 5-value return used below.
    model = ISMLBlock(N,D,hidden_size,device)
    model = model.to(device=device)
    model = torch.compile(model)    # Pytorch 2.0 acceleration
    model.train()
    y_e_list,y_c_list,s_final,cml_scores,eml_scores = model(s1)

    loss_total,cml_out,eml_out = loss_calc(y_e_list,y_c_list,doc_couples_b,cml_scores,eml_scores,sliding_mask)
    print(loss_total)

    with torch.no_grad():
        model.eval()
        out_ind = inference(cml_out,eml_out,mode='avg')
        print(out_ind)


test_wholesys()