## Download Dependencies

In [None]:
# Print the installed torch CUDA / version: the PyG extension wheels installed in the next
# cell come from the torch-1.7.0 / cu101 index, so these must read 10.1 and 1.7.0+cu101.
!python -c "import torch; print(torch.version.cuda)"
!python -c "import torch; print(torch.__version__)"

10.1
1.7.0+cu101


In [None]:
# PyTorch Geometric extension wheels for torch 1.7.0 + CUDA 10.1, then the remaining
# third-party dependencies.
# NOTE(review): `==latest+cu101` relies on the naming used by that wheel index and the
# other packages are unpinned -- pin exact versions (and prefer %pip) for a reproducible
# environment.
!pip install torch-scatter==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.7.0.html
!pip install torch-sparse==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.7.0.html
!pip install torch-cluster==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.7.0.html
!pip install torch-spline-conv==latest+cu101 -f https://pytorch-geometric.com/whl/torch-1.7.0.html
!pip install torch-geometric
!pip install flair
!pip install laserembeddings
!pip install dataclasses
!pip install dill

### **Load FrameNet Graph into Graph Attention Networks**

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(2020) # seed for reproducible numbers

from torch_geometric.data import Data
from torch_geometric.nn import GATConv
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T

import matplotlib.pyplot as plt
%matplotlib notebook

import warnings
warnings.filterwarnings("ignore")

import nltk
nltk.download("framenet_v17")
from nltk.corpus import framenet as fn
import networkx as nx
import numpy as np

print("...creating networkx FN...")
G = nx.DiGraph()
for frame in fn.frames():
    G.add_node(frame.ID)
    for adj in frame.frameRelations:
        G.add_edge(adj.superFrame.ID, adj.subFrame.ID)
        G.add_edge(adj.subFrame.ID, adj.superFrame.ID)

# initialize frame embeddings with LASER sentence representations 
print("...embedding frames...")
!python -m laserembeddings download-models
from laserembeddings import Laser
laser = Laser()
sentences = [fn.frame(frameID).definition for frameID in G.nodes]
frame_embeddings = laser.embed_sentences(sentences, lang='en')

# convert networkx G into torch.geometric graph
print("...generating torch_geometric graph...")

x = torch.from_numpy(frame_embeddings)  # x.shape = (1221, 1024)
nodes_to_x = {node: i for i, node in enumerate(G.nodes)}  # map frame ID to index position in x
x_to_nodes = {i: node for i, node in enumerate(G.nodes)}  # reverse of nodes_to_x
edge_index = torch.Tensor(list(set([(nodes_to_x[src], nodes_to_x[tgt]) for src, tgt in G.edges]))).long()

data = Data(x=x, edge_index=edge_index.t().contiguous())
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)

### **Auxiliary Tasks**

#### Any-Language

In [None]:
# monolingual task
import nltk
nltk.download('framenet_v17')
from nltk.corpus import framenet as fn
import torch
from dataclasses import dataclass, field

@dataclass
class Annotation:
    """
    One annotated sentence: its source annotation file, evoked frame, lexical unit,
    character-level LU / frame-element spans, and flair-tokenized views of the same.

    Presumably this is the record type stored in the annotation pickles that this
    notebook reads with ``torch.load``; code elsewhere reads ``frameName``, ``luName``,
    ``tokenized_text`` and ``tokenized_frame_idx`` from those objects.
    """
    annofile: str  # annotation file this record came from (presumably a path/name -- confirm)
    frameName: str  # frame name, later looked up in FrameNet via fn.frames(...)
    luName: str = ''  # "<lemma>.<pos>": code elsewhere reads luName.split('.')[1] as the POS tag
    lu_idx: list = field(default_factory=list)  # [(start_LU_idx, end_LU_idx [exclusive of space], id), ...]
    fe_idx: list = field(default_factory=list) # [(start_FE_idx, end_FE_idx [exclusive of space], feName, id), ...]

    # tokenized by flair
    tokenized_text: str = ''
    tokenized_lu_idx: list = field(default_factory=list)  # [(token_idx, LU), ...]
    # NOTE(review): the embedding helpers index this per token and compare each entry
    # to '-' (non-LU token), which suggests one entry per token rather than the
    # (token_idx, frame) pairs described here -- confirm the stored format.
    tokenized_frame_idx: list = field(default_factory=list)  # [(token_idx, frame), ...]
    tokenized_fe_idx: list = field(default_factory=list)  # [(token_idx, FE), ...]

#### Load Auxiliary Data

In [None]:
# Auxiliary-task annotation lists, loaded in this order: any-language first, then BFN.
# torch.load unpickles its input, so only point these paths at trusted files.
any_annos, bfn_annos = (
    torch.load(annos_path)
    for annos_path in (
        "any-language-frames/annos_fn_pos_tags.pt",
        "/content/Capstone/frame_embeddings_BFN/annos.pt",
    )
)

### **Actual Task**

### Load Data

In [None]:
# actual task
import nltk
nltk.download('framenet_v17')
from nltk.corpus import framenet as fn
import globalfn
from globalfn.annotations import annotation, all_annotations, annotation_annoID
from globalfn.alignments import all_alignments
import re

def extract_annoID(line):
    """Return the first run of decimal digits found in ``line`` as an int."""
    digit_runs = re.findall(r'\d+', line)
    first_run = digit_runs[0]
    return int(first_run)

# Aligned annotation pairs:
#   div_D[(en_annoID, tgt_annoID, tgt_lang)] = (en_frame_ID, tgt_frame_ID)
# Insertion order matters: later cells derive the row order of the LU, frame and
# POS tensors by iterating this dict.
div_D = {}


def _frame_id_by_name(frame_name):
    """
    Return the FrameNet ID of the frame called ``frame_name``.

    ``fn.frames(name)`` treats ``name`` as a regular expression searched inside frame
    names, so it can also return frames whose names merely contain ``frame_name``;
    taking ``[0]`` blindly can therefore pick the wrong frame. Prefer an exact name
    match and fall back to the first regex hit only when no exact match exists.
    Raises IndexError when nothing matches (as the previous ``[0]`` lookup did).
    """
    candidates = fn.frames(frame_name)
    exact = [frame for frame in candidates if frame.name == frame_name]
    return (exact or candidates)[0].ID


def _load_annotation_pairs(file_name, tgt_lang):
    """
    Parse ``Capstone/<file_name>`` and add its en-``tgt_lang`` pairs to ``div_D``.

    Records are separated by a line of '=' characters. Inside a record the first line
    mentioning "annoID" gives the English annotation ID and the second gives the
    ``tgt_lang`` one; any further annoID lines in the same record are ignored.
    """
    print("Load", file_name)
    with open("Capstone/" + file_name, "r") as rf:
        en_anno = tgt_anno = None
        for line in rf:
            if line.strip() == "===============================":
                en_anno = tgt_anno = None

            if "annoID" not in line:
                continue
            if en_anno is None:
                en_anno = extract_annoID(line)
            elif tgt_anno is None:
                tgt_anno = extract_annoID(line)

                src_frame_id = _frame_id_by_name(annotation_annoID('en', en_anno).frameName)
                tgt_frame_id = _frame_id_by_name(annotation_annoID(tgt_lang, tgt_anno).frameName)
                div_D[(en_anno, tgt_anno, tgt_lang)] = (src_frame_id, tgt_frame_id)


# NOTE(review): earlier comments labelled *.results.txt as the "same" pairs and
# *.same.results.txt as the "diverging frames" pairs, which contradicts the file
# names -- confirm which file holds which. The load order is unchanged.
for file_name, tgt_lang in (
    ("en-pt.results.txt", "pt"),
    ("en-de.results.txt", "de"),
    ("en-pt.same.results.txt", "pt"),
    ("en-de.same.results.txt", "de"),
):
    _load_annotation_pairs(file_name, tgt_lang)

In [None]:
# Distinct frames, lexical units and annotation IDs covered by the aligned pairs.
# NOTE(review): English and target-language annotation IDs share one `sents` set, so
# equal numeric IDs from different languages are counted once.
frames, lus, sents = set(), set(), set()

for (en_anno, tgt_anno, tgt_lang), frame_pair in div_D.items():
    frames.update(frame_pair)
    lus.add(annotation_annoID('en', en_anno).luName)
    lus.add(annotation_annoID(tgt_lang, tgt_anno).luName)
    sents.update((en_anno, tgt_anno))

print(len(frames), len(lus), len(sents))  # check length

179 476 788


In [None]:
# Load multilingual DistilBERT (cased) through flair for LU embedding;
# layers='-1' keeps only the final transformer layer's token embeddings.
from flair.data import Sentence
from flair.embeddings import TransformerWordEmbeddings
mbert = TransformerWordEmbeddings('distilbert-base-multilingual-cased', layers='-1')

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def get_lu_embedding(lang_model, anno_lang, anno_ID):
    """
    Embed the lexical unit of one annotation.

    lang_model: flair embedding model used to embed the annotation's tokenized text.
    anno_lang: language code passed to ``annotation_annoID`` (e.g. 'en', 'pt', 'de').
    anno_ID: annotation ID passed to ``annotation_annoID``.

    Returns the mean of the token embeddings whose ``tokenized_frame_idx`` entry is not
    '-'. Raises ValueError when no token is marked.
    """
    # Look the annotation up once instead of once per token
    # (assumes annotation_annoID is a side-effect-free lookup).
    anno = annotation_annoID(anno_lang, anno_ID)
    sent = Sentence(anno.tokenized_text, use_tokenizer=False)
    lang_model.embed(sent)

    lu_embeds = [tok.embedding for i, tok in enumerate(sent)
                 if anno.tokenized_frame_idx[i] != '-']
    if not lu_embeds:
        raise ValueError("no LU token marked in annotation {} ({})".format(anno_ID, anno_lang))
    return torch.mean(torch.stack(lu_embeds), dim=0)

# Embed the LU of every aligned pair, walking div_D in key order so row k of
# tgt_lus / src_lus corresponds to the k-th pair (src_frames / tgt_frames are built
# from div_D.values() in the same order).
tgt_L = list()
for key in div_D:
    _, anno_ID, lang = key
    print(anno_ID)
    tgt_L.append(get_lu_embedding(mbert, lang, anno_ID))

src_L = list()
for key in div_D:
    anno_ID, _, lang = key
    print(anno_ID)
    src_L.append(get_lu_embedding(mbert, 'en', anno_ID))

tgt_lus, src_lus = (torch.stack(stack_me).float().to(device) for stack_me in (tgt_L, src_L))

In [None]:
# Node indices (rows of data.x) of each aligned pair's source and target frame,
# in div_D order so they line up with src_lus / tgt_lus.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)
frame_id_pairs = list(div_D.values())
src_frames = torch.LongTensor([nodes_to_x[pair[0]] for pair in frame_id_pairs]).to(device)
tgt_frames = torch.LongTensor([nodes_to_x[pair[1]] for pair in frame_id_pairs]).to(device)

In [None]:
# POS index of each aligned pair's source (English) and target LU, in div_D order.
# The tag is the part of luName after '.', e.g. "v" in "run.v".
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
pos_to_ind = {
    'v': 0, 'n': 1, 'a': 2, 'art': 3, 'adv': 4, 'prep': 5, 'idio': 6,
    'scon': 7, 'c': 8, 'num': 9, 'intj': 10, 'pron': 11, 'avp': 12,
}

tgt_pos = [pos_to_ind[annotation_annoID(lang, anno_ID).luName.split('.')[1]]
           for _, anno_ID, lang in div_D]
src_pos = [pos_to_ind[annotation_annoID('en', anno_ID).luName.split('.')[1]]
           for anno_ID, _, lang in div_D]

src_pos = torch.LongTensor(src_pos).to(device)
tgt_pos = torch.LongTensor(tgt_pos).to(device)

## **Experiments**

In [None]:
from torch_geometric.utils import dropout_adj
from torch_geometric.nn import GATConv

class NodeNorm(nn.Module):
    """
    Node normalization: standardize every row (node) of ``x`` across its features.

    unbiased: use the unbiased (N-1) variance estimate instead of the population one.
    eps: added to the variance before the square root for numerical stability.
    """
    def __init__(self, unbiased=False, eps=1e-5):
        super(NodeNorm, self).__init__()
        self.unbiased = unbiased
        self.eps = eps

    def forward(self, x):
        centered = x - x.mean(dim=1, keepdim=True)
        row_var = torch.var(x, unbiased=self.unbiased, dim=1, keepdim=True)
        return centered / (row_var + self.eps).sqrt()

class GAT(torch.nn.Module):
    """
    Two-layer Graph Attention Network encoder over the FrameNet graph.

    data: torch_geometric graph; only ``data.num_features`` is read here, to size the
        first layer.
    hid, in_head: per-head width and head count of the first GATConv (heads are
        concatenated, giving ``hid * in_head`` features).
    hid2, out_head: output width and head count of the second GATConv; its heads are
        averaged (concat=False), so each node ends with ``hid2`` features.
    """
    def __init__(self, data, hid=109, hid2=256, in_head=9, out_head=10):
        super(GAT, self).__init__()
        self.hid = hid
        self.hid2 = hid2
        self.in_head = in_head
        self.out_head = out_head

        self.node_norm = NodeNorm()
        self.conv1 = GATConv(data.num_features, self.hid, heads=self.in_head, dropout=0.6)
        self.conv2 = GATConv(self.hid*self.in_head, self.hid2, concat=False,
                             heads=self.out_head, dropout=0.6)

    def forward(self, data, training=True):
        """
        Encode every node of ``data`` (uses ``data.x`` and ``data.edge_index``).

        training: pass False at evaluation time. DropEdge and feature dropout run only
            when this flag is True AND the module is in train mode.
        Returns node-normalized representations with ``hid2`` features per node.
        """
        # F.dropout with an explicit flag replaces nn.Dropout(...) built inside forward():
        # a module created here is always in train mode, so model.eval() could never
        # switch it off and evaluation outputs were randomly dropped.
        stochastic = training and self.training

        # DropEdge
        edge_index, _ = dropout_adj(data.edge_index, training=stochastic)

        # Dropout before each GAT layer guards against overfitting on small graphs.
        x = F.dropout(data.x, p=0.4, training=stochastic)
        x = self.conv1(x, edge_index)
        x = F.gelu(x)
        x = self.node_norm(x)
        x = F.dropout(x, p=0.4, training=stochastic)
        x = self.conv2(x, edge_index)
        x = self.node_norm(x)
        return x

### Auxiliary Training

In [None]:
def get_lu_embedding_helper(lang_model, anno):
    """
    Use ``lang_model`` to embed the sentence in ``anno`` and return the LU embedding.

    The LU embedding is the mean of the token embeddings whose
    ``anno.tokenized_frame_idx`` entry is not '-'.

    Raises ValueError (after printing diagnostics) when no token is marked. This
    replaces a bare ``except`` + ``assert False``, which also swallowed unrelated errors
    and is removed entirely under ``python -O``.
    """
    sent = Sentence(anno.tokenized_text, use_tokenizer=False)
    lang_model.embed(sent)

    lu_embeds = [tok.embedding for i, tok in enumerate(sent)
                 if anno.tokenized_frame_idx[i] != '-']

    if not lu_embeds:
        print(anno)
        for i, tok in enumerate(sent):
            print(tok, anno.tokenized_frame_idx[i])
        print(lu_embeds, len(anno.tokenized_frame_idx))
        raise ValueError("no LU token marked in annotation: {}".format(anno))
    return torch.mean(torch.stack(lu_embeds), dim=0)

def get_lu_embeddings(annos, start, end):
    """
    Embed the lexical unit of every annotation in ``annos[start:end]`` with ``mbert``.

    annos: list of annotations; start / end: slice bounds.
    Prints every 1000th annotation as a progress marker and returns a float tensor on
    ``device`` with one row per annotation.
    """
    rows = []
    for offset, anno in enumerate(annos[start:end]):
        if not offset % 1000:
            print(anno)
        rows.append(get_lu_embedding_helper(mbert, anno))
    return torch.stack(rows).float().to(device)

In [None]:
def get_src_tgt_frames(annos, start, end):
    """
    Get source and target frame indices (rows of ``data.x``) for ``annos[start:end]``.

    annos: list of annotations; start / end: slice bounds.
    FrameNet 1.5 frame names are first mapped to their 1.7 names. Because
    ``fn.frames(name)`` does a regex search and may return frames whose names only
    contain ``name``, an exact name match is preferred; the mismatch diagnostics are
    printed only when no exact match exists.

    Returns (pre_src_frames, pre_tgt_frames): identical LongTensors on ``device``
    (the target is the correct frame itself).
    """
    # torch.load unpickles its input -- only load trusted files.
    fn15_to_fn17_mapping = torch.load("any-language-frames/fn15_to_fn17_mapping.pt")
    pre_src_frames = list()
    for i, anno in enumerate(annos[start:end]):
        if i % 1000 == 0:
            print(anno)

        frameName = fn15_to_fn17_mapping.get(anno.frameName, anno.frameName)
        candidates = fn.frames(frameName)
        exact = [frame for frame in candidates if frame.name == frameName]
        retrieved_frame = (exact or candidates)[0]
        pre_src_frames.append(retrieved_frame.ID)
        if retrieved_frame.name != frameName:
            print(frameName, retrieved_frame.name)
            print(candidates)

    # correct frame-to-frame
    pre_src_frames = torch.LongTensor([nodes_to_x[src_frame_id] for src_frame_id in pre_src_frames]).to(device)
    pre_tgt_frames = pre_src_frames.clone()
    return pre_src_frames, pre_tgt_frames

In [None]:
def get_pos(annos, start, end):
    """
    Map the POS suffix of each ``luName`` in ``annos[start:end]`` to an integer index.

    ``luName`` is split on '.' and the second piece (e.g. "v" in "run.v") is looked up
    in a fixed 13-tag table; an unknown tag raises KeyError.
    Returns a LongTensor of indices on ``device``.
    """
    pos_index = {
        'v': 0, 'n': 1, 'a': 2, 'art': 3, 'adv': 4, 'prep': 5, 'idio': 6,
        'scon': 7, 'c': 8, 'num': 9, 'intj': 10, 'pron': 11, 'avp': 12,
    }
    indices = [pos_index[anno.luName.split('.')[1]] for anno in annos[start:end]]
    return torch.LongTensor(indices).to(device)

In [None]:
import random

# Auxiliary-task pool: 15,000 draws from bfn_annos followed by all of any_annos.
# NOTE(review): random.choices samples WITH replacement, so some BFN annotations can
# appear more than once; random.sample would give 15,000 distinct ones.
chosen_bfn_annos = random.choices(bfn_annos, k=15000)
annos = chosen_bfn_annos + any_annos
n_annos = len(annos)
pre_lus = get_lu_embeddings(annos, 0, n_annos)
pre_src_frames, pre_tgt_frames = get_src_tgt_frames(annos, 0, n_annos)
pre_pos = get_pos(annos, 0, n_annos)

In [None]:
class MultiTaskLossWrapper(nn.Module):
    """
    Combine several task losses using learned homoscedastic-uncertainty weights.

    Task i owns a learnable log-variance s_i (initialized to 0) and contributes
    exp(-s_i) * L_i + s_i. Noisy tasks can therefore be down-weighted, while the
    + s_i term keeps s_i from growing without bound.

    task_num: number of losses / log-variance parameters.
    forward(losses): sequence of task losses, indexed in the same order as log_vars.
    """
    def __init__(self, task_num):
        super(MultiTaskLossWrapper, self).__init__()
        self.task_num = task_num
        self.log_vars = nn.Parameter(torch.zeros((task_num)))

    def forward(self, losses):
        weighted_terms = []
        for task_idx in range(len(losses)):
            log_var = self.log_vars[task_idx]
            weighted_terms.append(torch.sum(torch.exp(-log_var) * losses[task_idx] + log_var, -1))
        return torch.mean(sum(weighted_terms))

### All

##### All + Nested CV

In [None]:
import random
def frame_pairs_in_relation_helper(frame_id1, frame_id2):
    """
    Return 1 if ``frame_id2`` is the sub- or super-frame of any of ``frame_id1``'s
    frame-to-frame relations, else 0.
    """
    relations = fn.frame(frame_id1).frameRelations
    is_related = any(frame_id2 in (rel.subID, rel.supID) for rel in relations)
    return 1 if is_related else 0


def generate_frame_pairs_in_relation(num_pairs):
    """
    Generate labeled frame pairs for binary frame-to-frame relation classification.

    num_pairs: number of (frame, frame) pairs to generate.

    Each draw picks, with probability 0.5, either a uniformly random pair of frames
    (usually unrelated) or a frame together with one of its direct relation partners.
    Pairs of a frame with itself are skipped.

    Returns (fr_frame_1, fr_frame_2, fr_targets) as LongTensors on ``device``. The first
    two hold node indices into ``x`` (via ``nodes_to_x``); ``fr_targets`` is 1 when the
    frames are directly related, else 0.
    """
    all_frames = fn.frames()  # loop-invariant; yields the same random draws as re-calling it
    frame_pairs_in_relation = list()
    while len(frame_pairs_in_relation) < num_pairs:
        if random.random() < 0.5:
            f1 = random.choice(all_frames).ID
            f2 = random.choice(all_frames).ID
        else:
            frame = random.choice(all_frames)
            while not frame.frameRelations:
                frame = random.choice(all_frames)
            relation = random.choice(frame.frameRelations)
            f2 = relation.subID if relation.subID != frame.ID else relation.supID
            f1 = frame.ID

        if f1 == f2:
            continue

        frame_pairs_in_relation.append((f1, f2, frame_pairs_in_relation_helper(f1, f2)))

    fr_frame_1 = torch.LongTensor([nodes_to_x[f1] for f1, _, _ in frame_pairs_in_relation]).to(device)
    # index by the pair's own second frame; previously this read the enclosing `f1`
    # (the last frame drawn), so every second frame was the same node
    fr_frame_2 = torch.LongTensor([nodes_to_x[f2] for _, f2, _ in frame_pairs_in_relation]).to(device)
    fr_targets = torch.LongTensor([target for _, _, target in frame_pairs_in_relation]).to(device)
    return fr_frame_1, fr_frame_2, fr_targets

In [None]:
import nltk
nltk.download("framenet_v17")
from nltk.corpus import framenet as fn
import networkx as nx

# Load FrameNet into networkx graph: one node per frame, plus both directions of
# every frame-to-frame relation (same construction as the graph-loading cell).
G = nx.DiGraph()
for frame in fn.frames():
    G.add_node(frame.ID)
    G.add_edges_from(
        directed_pair
        for rel in frame.frameRelations
        for directed_pair in ((rel.superFrame.ID, rel.subFrame.ID),
                              (rel.subFrame.ID, rel.superFrame.ID))
    )

def generate_fr_dist_data(G, size=20000):
    """
    Sample frame pairs and their graph distance for the frame-distance regression task.

    G: FrameNet graph (networkx), whose nodes are frame IDs present in ``nodes_to_x``.
    size: number of ordered pairs of distinct frames to sample.

    Returns (src_nodes, tgt_nodes, dist). src_nodes / tgt_nodes are LongTensors of node
    indices into ``x``; dist is a FloatTensor of shortest-path lengths, with 0 meaning
    "no path" (distinct nodes never have a true distance of 0). All are on ``device``.
    """
    src_nodes = []
    tgt_nodes = []
    dist = []
    all_nodes = list(G)
    while len(dist) < size:
        src_node = random.choice(all_nodes)
        tgt_node = random.choice(all_nodes)
        if src_node == tgt_node:
            continue
        src_nodes.append(nodes_to_x[src_node])
        tgt_nodes.append(nodes_to_x[tgt_node])
        # one traversal instead of has_path() followed by shortest_path_length()
        try:
            dist.append(nx.shortest_path_length(G, src_node, tgt_node))
        except nx.NetworkXNoPath:
            dist.append(0)

    src_nodes = torch.LongTensor(src_nodes).to(device)
    tgt_nodes = torch.LongTensor(tgt_nodes).to(device)
    dist = torch.FloatTensor(dist).to(device)
    return src_nodes, tgt_nodes, dist

In [None]:
def scramble_frames(pre_tgt_frames, p=0.3):
    """
    Randomly perturb frame labels for the binary frame-prediction task.

    pre_tgt_frames: 1-D tensor of correct frame indices (rows of ``x``).
    p: probability that an entry is replaced by a uniformly random node index.

    Returns (pre_scrambled_tgt_frames, pre_scrambled_targets) as LongTensors on
    ``device``. A target is 1 when the (possibly replaced) frame equals the correct
    one, including a random draw that happens to hit it, and 0 otherwise.
    """
    import random
    # Hoisted out of the loop: the candidate list was rebuilt for every entry, and
    # .item() forced one device->host copy per element.
    candidate_frames = list(x_to_nodes.keys())
    gold_frames = pre_tgt_frames.tolist()

    pre_scrambled_tgt_frames = list()
    pre_scrambled_targets = list()
    for gold in gold_frames:
        if random.random() < p:
            candidate = random.choice(candidate_frames)
            pre_scrambled_tgt_frames.append(candidate)
            pre_scrambled_targets.append(1 if candidate == gold else 0)
        else:
            pre_scrambled_tgt_frames.append(gold)
            pre_scrambled_targets.append(1)

    pre_scrambled_tgt_frames = torch.LongTensor(pre_scrambled_tgt_frames).to(device)
    pre_scrambled_targets = torch.LongTensor(pre_scrambled_targets).to(device)
    return pre_scrambled_tgt_frames, pre_scrambled_targets

# Module-level scrambled sample, handy for inspection. run() draws its own fresh sample
# into local variables for every inner fold and does not read these globals.
pre_scrambled_tgt_frames, pre_scrambled_targets = scramble_frames(pre_tgt_frames)

In [None]:
def accuracy(output, target, topk=(1,)):
    """
    Top-k accuracy, in percent, for every k in ``topk``.

    output: (batch, num_classes) scores; target: (batch,) class indices.
    Returns a list of floats in the same order as ``topk``.
    """
    with torch.no_grad():
        n_rows = target.size(0)
        ranked = output.topk(max(topk), dim=1, largest=True, sorted=True).indices.t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))
        return [
            hits[:k].flatten().float().sum(0, keepdim=True).mul_(100.0 / n_rows).item()
            for k in topk
        ]

In [None]:
class FFN_frame(torch.nn.Module):
    """
    Linear multi-class frame classifier, used for frame shift prediction and frame
    label reconstruction.

    input_dim: width of ``x``; num_class: number of output logits;
    pos_dim: width of the POS embedding (a shared table of 13 POS tags).
    forward(x, src_pos, tgt_pos): appends both POS embeddings to ``x`` and returns
    logits of shape (rows, num_class).
    """
    def __init__(self, input_dim, num_class, pos_dim=16):
        super(FFN_frame, self).__init__()
        self.input_dim, self.num_class, self.pos_dim = input_dim, num_class, pos_dim

        self.pos_embedding = nn.Embedding(13, pos_dim)
        self.linear1 = nn.Linear(input_dim + 2 * pos_dim, num_class)

    def forward(self, x, src_pos, tgt_pos):
        features = torch.cat(
            [x, self.pos_embedding(src_pos), self.pos_embedding(tgt_pos)], dim=1
        )
        return self.linear1(features)

class FFN_pred_frame(torch.nn.Module):
    """
    Binary frame-prediction head: 2 logits per row of (frame representation, LU).

    input_dim: width of ``x``; num_class: accepted for signature parity with FFN_frame
    but unused (the head always emits 2 logits); pos_dim: width of the POS embedding
    (13 POS tags).
    """
    def __init__(self, input_dim, num_class, pos_dim=16):
        super(FFN_pred_frame, self).__init__()
        self.input_dim = input_dim

        self.pos_embedding = nn.Embedding(13, pos_dim)
        in_features = input_dim + pos_dim
        self.linear1 = nn.Linear(in_features, 2)

    def forward(self, x, pos):
        return self.linear1(torch.cat((x, self.pos_embedding(pos)), dim=1))

class FFN_relation_classifier(torch.nn.Module):
    """
    Binary frame-to-frame relation classifier: 2 logits for each (frame1, frame2) row
    pair.

    input_dim: combined width of frame1 and frame2. ffn_hid and pos_dim are accepted
    but unused.
    """
    def __init__(self, input_dim, ffn_hid = 8, pos_dim=16):
        super(FFN_relation_classifier, self).__init__()
        self.input_dim = input_dim
        self.linear1 = nn.Linear(input_dim, 2)

    def forward(self, frame1, frame2):
        pair_features = torch.cat((frame1, frame2), dim=1)
        return self.linear1(pair_features)

class FFN_pred_fr_dist(torch.nn.Module):
    """
    Frame-distance regressor: a 1024-unit ReLU MLP that maps each concatenated
    (x1, x2) row pair to a single scalar.

    input_dim: combined width of x1 and x2. Output shape is (rows, 1).
    """
    def __init__(self, input_dim):
        super(FFN_pred_fr_dist, self).__init__()
        self.input_dim = input_dim
        self.linear1 = nn.Linear(self.input_dim, 1024)
        self.linear2 = nn.Linear(1024, 1)

    def forward(self, x1, x2):
        hidden = torch.relu(self.linear1(torch.cat((x1, x2), dim=1)))
        return self.linear2(hidden)


def cross_val_split(data, fold, i):
    """
    Cross-validation split along dim 0.

    data: tensor to split (rows stay in order).
    fold: total number of folds.
    i: index of the held-out fold, 0 <= i < fold.

    The first ``fold - 1`` folds hold ``len(data) // fold`` rows each, and the last fold
    takes all remaining rows. Every row is therefore held out exactly once over
    i = 0..fold-1. (Previously the offset was computed from the current fold's own
    size, so the larger last fold started too late and some rows were never held out.)

    Returns (non_held_out, held_out).
    """
    fold_size = len(data) // fold
    start = i * fold_size
    end = start + fold_size if i < fold - 1 else len(data)
    held_out = data[start:end]
    non_held_out = torch.cat([data[:start], data[end:]])
    return non_held_out, held_out


def _build_fsp_models():
    """
    Create a fresh GAT encoder, the four task heads, the uncertainty-weighting loss
    wrapper and an Adam optimizer over all of their parameters.

    Returns (model, heads, wrapper, optimizer), where
    heads = (aux_fr_dist_classifier, aux_fr_classifier, aux_frame_classifier, frame_classifier).
    """
    model = GAT(data).to(device)
    aux_fr_dist_classifier = FFN_pred_fr_dist(256 * 2).to(device)
    aux_fr_classifier = FFN_relation_classifier(256 * 2).to(device)
    aux_frame_classifier = FFN_pred_frame(256 + 768, 1221).to(device)
    frame_classifier = FFN_frame(256 + 768 + 768, 1221).to(device)
    # The wrapper's log-variances are trainable parameters: they must be on the same
    # device and handed to the optimizer. Otherwise they stay at 0 and every task is
    # weighted 1, which makes the uncertainty weighting a no-op.
    wrapper = MultiTaskLossWrapper(5).to(device)

    heads = (aux_fr_dist_classifier, aux_fr_classifier, aux_frame_classifier, frame_classifier)
    network_params = list(model.parameters())
    for head in heads:
        network_params += list(head.parameters())
    optimizer = torch.optim.Adam(
        [
            {'params': network_params},
            # no weight decay on the log-variances: it would pull every task weight back towards 1
            {'params': list(wrapper.parameters()), 'weight_decay': 0.0},
        ],
        lr=0.005, weight_decay=5e-4)
    return model, heads, wrapper, optimizer


def _fsp_losses(out, heads, aux_data, split):
    """
    Compute the five training losses, always in the order
    [fr_dist, fr_binary, binary_frame, frame_restoration, main_fsp].

    out: GAT node representations.
    heads: tuple as returned by _build_fsp_models.
    aux_data: dict with 'aux_fr', 'aux_fr_dist' and 'aux_frame' tuples.
    split: (src_frames, src_lus, tgt_lus, src_pos, tgt_pos, tgt_frames) for the main
        frame-shift-prediction task.
    """
    aux_fr_dist_classifier, aux_fr_classifier, aux_frame_classifier, frame_classifier = heads
    fr_frame_1, fr_frame_2, fr_targets = aux_data['aux_fr']
    src_nodes, tgt_nodes, dist = aux_data['aux_fr_dist']
    scrambled_frames, scrambled_targets = aux_data['aux_frame']
    s_frames, s_lus, t_lus, s_pos, t_pos, t_frames = split
    mse, ce = nn.MSELoss(), nn.CrossEntropyLoss()

    # frame distance: the head outputs (N, 1); squeeze it to match dist (N,) instead
    # of silently broadcasting to an (N, N) comparison
    y = aux_fr_dist_classifier(out[src_nodes], out[tgt_nodes]).squeeze(-1)
    aux_fr_dist_loss = mse(y, dist)

    # binary frame-to-frame relation
    aux_fr_loss = ce(aux_fr_classifier(out[fr_frame_1], out[fr_frame_2]), fr_targets)

    # binary frame induction prediction on scrambled frames
    tmp_out = torch.cat([out[scrambled_frames], pre_lus], dim=1)
    aux_frame_loss = ce(aux_frame_classifier(tmp_out, pre_pos), scrambled_targets)

    # frame restoration
    tmp_out = torch.cat([out[scrambled_frames], pre_lus, pre_lus], dim=1)
    aux_frame_2_loss = ce(frame_classifier(tmp_out, pre_pos, pre_pos), pre_tgt_frames)

    # actual task
    tmp_out = torch.cat([out[s_frames], s_lus, t_lus], dim=1)
    main_loss = ce(frame_classifier(tmp_out, s_pos, t_pos), t_frames)

    return [aux_fr_dist_loss, aux_fr_loss, aux_frame_loss, aux_frame_2_loss, main_loss]


def _fsp_logits(model, frame_classifier, split_frames, split_src_lus, split_tgt_lus,
                split_src_pos, split_tgt_pos):
    """Frame-classifier logits for one split, with the encoder in eval mode and no gradients."""
    model.eval()
    with torch.no_grad():
        out = model(data, training=False)
        feats = torch.cat([out[split_frames], split_src_lus, split_tgt_lus], dim=1)
        return frame_classifier(feats, split_src_pos, split_tgt_pos)


def run():
    """
    One repetition of nested cross-validation for frame shift prediction with four
    auxiliary tasks.

    The outer OUTER_FOLD-fold CV over the aligned pairs estimates test top-5 accuracy.
    Within each outer fold, every inner fold trains with freshly sampled auxiliary
    data and early-stops on validation loss (checked every EVAL_EVERY epochs). The
    inner fold with the lowest best validation loss supplies the auxiliary data and
    the number of optimizer steps used to retrain on the whole outer-training split,
    which is then evaluated on the held-out outer fold.

    Returns the mean top-5 test accuracy (percent) over the outer folds.
    """
    OUTER_FOLD = 5
    INNER_FOLD = 5
    ES_PATIENCE = 2     # consecutive non-improving validation checks before stopping
    MAX_EPOCHS = 1000
    EVAL_EVERY = 10     # epochs between validation checks

    # row-aligned per-pair tensors, always split together:
    # (src_frames, src_lus, tgt_lus, src_pos, tgt_pos, tgt_frames)
    pair_tensors = (src_frames, src_lus, tgt_lus, src_pos, tgt_pos, tgt_frames)

    test_accs = list()
    for fold_idx in range(OUTER_FOLD):
        print("Outer CV:", fold_idx)

        # outer CV
        outer = [cross_val_split(t, OUTER_FOLD, fold_idx) for t in pair_tensors]
        all_train = tuple(train_part for train_part, _ in outer)
        test_split = tuple(held_part for _, held_part in outer)

        # inner CV: select the auxiliary data and training length
        candidates = list()  # (best val loss, optimizer steps to reach it, aux data)
        for inner_fold_idx in range(INNER_FOLD):
            print("Inner CV:", inner_fold_idx)
            model, heads, wrapper, optimizer = _build_fsp_models()

            inner = [cross_val_split(t, INNER_FOLD, inner_fold_idx) for t in all_train]
            train_split = tuple(train_part for train_part, _ in inner)
            val_split = tuple(held_part for _, held_part in inner)

            # randomly generated data for the auxiliary tasks
            aux_data = {
                'aux_fr': generate_frame_pairs_in_relation(20000),
                'aux_fr_dist': generate_fr_dist_data(G),
                'aux_frame': scramble_frames(pre_tgt_frames),
            }

            min_val_loss = float('inf')
            best_steps = 1
            epochs_no_improve = 0
            for epoch in range(MAX_EPOCHS):
                model.train()
                optimizer.zero_grad()
                losses = _fsp_losses(model(data), heads, aux_data, train_split)
                total_loss = wrapper(losses)
                total_loss.backward()
                optimizer.step()

                if epoch % EVAL_EVERY != 0:
                    continue
                print("Losses Breakdown:", total_loss.item(), *[loss.item() for loss in losses])
                val_logits = _fsp_logits(model, heads[3], *val_split[:5])
                val_loss = nn.CrossEntropyLoss()(val_logits, val_split[5]).item()

                # early stopping; remember how many steps produced the best val loss
                if val_loss < min_val_loss:
                    min_val_loss = val_loss
                    best_steps = epoch + 1
                    epochs_no_improve = 0
                else:
                    epochs_no_improve += 1
                    if epochs_no_improve >= ES_PATIENCE:
                        print("early stopping")
                        break
            # recorded whether or not early stopping fired, so selection never sees an empty set
            candidates.append((min_val_loss, best_steps, aux_data))

        ##################################################################################################################
        # model testing: retrain with all outer-training data -> eval with test
        _, best_steps, best_aux = min(candidates, key=lambda cand: cand[0])
        model, heads, wrapper, optimizer = _build_fsp_models()
        model.train()
        for _ in range(best_steps):
            optimizer.zero_grad()
            total_loss = wrapper(_fsp_losses(model(data), heads, best_aux, all_train))
            total_loss.backward()
            optimizer.step()

        test_logits = _fsp_logits(model, heads[3], *test_split[:5])
        test_accs.append(accuracy(test_logits, test_split[5], topk=(5,))[0])

    return sum(test_accs)/len(test_accs)

# Repeat the whole nested-CV experiment 10 times; each entry is run()'s mean top-5
# test accuracy (percent) over the outer folds.
results = [run() for _ in range(10)]

In [None]:
# per-repetition mean top-5 test accuracies (percent)
results