Import libraries

In [None]:
from collections import defaultdict
from itertools import count
import itertools
import random
import math
import json
import logging
import pathlib
import sys

import pandas as pd
import numpy as np
import tqdm # progree bar

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, Sampler
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from transformers import AutoModel, AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM
from transformers.modeling_outputs import Seq2SeqLMOutput

DATA 

In [None]:
def read_json(path):
    """Parse the UTF-8 encoded JSON file at `path` and return its content."""
    with open(path, mode='r', encoding='utf-8') as handle:
        return json.load(handle)

def write_json(path, data):
    """Serialise `data` as JSON into `path` (UTF-8); an existing file is overwritten."""
    with open(path, mode='w', encoding='utf-8') as handle:
        json.dump(obj=data, fp=handle)


class ARDDataset(Dataset):
    """Map-style dataset over a JSON file holding a list of dictionary entries.

    Every sample is returned as a plain tuple. Test datasets expose only
    ``(id, word, gloss)``; otherwise the three target embeddings
    ``(electra, bertseg, bertmsa)`` are appended in that order.
    """

    # Keys present in every entry, followed by the keys only read when not testing.
    _BASE_FIELDS = ("id", "word", "gloss")
    _EMBEDDING_FIELDS = ("electra", "bertseg", "bertmsa")

    def __init__(self, path, is_test=False) -> None:
        """Read the whole JSON file at `path` into memory."""
        super().__init__()
        self.is_test = is_test
        self.data = read_json(path)

    def __getitem__(self, index):
        """Return the tuple of fields for the entry at `index`."""
        entry = self.data[index]
        if self.is_test:
            fields = self._BASE_FIELDS
        else:
            fields = self._BASE_FIELDS + self._EMBEDDING_FIELDS
        return tuple(entry[name] for name in fields)

    def __len__(self):
        """Number of entries loaded from the file."""
        return len(self.data)

# Special vocabulary tokens: sequence start / end markers, padding filler and
# the replacement for out-of-vocabulary words.
BOS, EOS = "<seq>", "</seq>"
PAD, UNK = "<pad/>", "<unk/>"

# Names of the static embedding architectures looked up in each JSON entry.
SUPPORTED_ARCHS = ["sgns"]

# A dataset is a container object for the actual data
class JSONDataset(Dataset):
    """Reads a CODWOE-style JSON dataset and pre-computes tensors for each item.

    Each item stays a dictionary; `<key>_tensor` entries are added next to the
    raw `gloss` / embedding fields that are present.
    """

    # Contextual embedding fields that may accompany each entry.
    _CONTEXTUAL_ARCHS = ("electra", "bertseg", "bertmsa")

    def __init__(self, file, vocab=None, freeze_vocab=False, maxlen=256):
        """
        Construct a torch.utils.data.Dataset compatible with torch data API and
        codwoe data.
        args: `file` the path to the dataset file (read as UTF-8)
              `vocab` a dictionary mapping strings to indices
              `freeze_vocab` whether to update vocabulary, or just replace unknown items with OOV token
              `maxlen` the maximum number of tokens per gloss (falsy disables truncation)
        """
        if vocab is None:
            self.vocab = defaultdict(count().__next__)
        else:
            self.vocab = defaultdict(count(len(vocab)).__next__)
            self.vocab.update(vocab)
        # Looking the special tokens up registers any that are missing, in this order.
        _pad, eos, bos, unk = (
            self.vocab[PAD],
            self.vocab[EOS],
            self.vocab[BOS],
            self.vocab[UNK],
        )
        if freeze_vocab:
            # Freeze the mapping built above rather than the raw `vocab` argument:
            # this keeps working when `vocab` is None and retains special tokens
            # that had to be registered just now.
            self.vocab = dict(self.vocab)
        with open(file, "r", encoding="utf-8") as istr:
            self.items = json.load(istr)
        # preparse data
        for json_dict in self.items:
            # in definition modeling test datasets, gloss targets are absent
            if "gloss" in json_dict:
                words = json_dict["gloss"].split()
                if freeze_vocab:
                    indices = [self.vocab.get(word, unk) for word in words]
                else:
                    indices = [self.vocab[word] for word in words]
                gloss_tensor = torch.tensor([bos] + indices + [eos])
                if maxlen:
                    gloss_tensor = gloss_tensor[:maxlen]
                json_dict["gloss_tensor"] = gloss_tensor
            # in reverse dictionary test datasets, vector targets are absent.
            # Every embedding that is present gets its own tensor; the `has_*`
            # flags below are independent, so the tensors must be too.
            for arch in (*SUPPORTED_ARCHS, *self._CONTEXTUAL_ARCHS):
                if arch in json_dict:
                    json_dict[f"{arch}_tensor"] = torch.tensor(json_dict[arch])
        # The first item is taken as representative of the whole file; an empty
        # dataset simply has none of the optional fields.
        first = self.items[0] if self.items else {}
        self.has_gloss = "gloss" in first
        self.has_vecs = SUPPORTED_ARCHS[0] in first
        self.has_electra = "electra" in first
        self.has_bertseg = "bertseg" in first
        self.has_bertmsa = "bertmsa" in first
        self.itos = sorted(self.vocab, key=lambda w: self.vocab[w])

    def __len__(self):
        """Number of items read from the file."""
        return len(self.items)

    def __getitem__(self, index):
        """Return the (augmented) dictionary stored at `index`."""
        return self.items[index]

    # we're adding this method to simplify the code in our predictions of
    # glosses
    def decode(self, tensor):
        """Convert a sequence of indices (possibly batched) to tokens.

        A 2-D tensor is treated as [Seq x Batch] and yields a list of strings;
        otherwise a single space-joined string is returned. Padding is skipped.
        """
        with torch.no_grad():
            if tensor.dim() == 2:
                # we have batched tensors of shape [Seq x Batch]
                return [self.decode(column) for column in tensor.t()]
            pad_idx = self.vocab[PAD]
            return " ".join(self.itos[i.item()] for i in tensor if i != pad_idx)

    def save(self, file):
        """Pickle the whole dataset object to `file` with torch.save."""
        torch.save(self, file)

    @staticmethod
    def load(file):
        """Load a dataset previously written by `save`.

        NOTE(review): torch.load unpickles arbitrary objects -- only use it on
        trusted files. Recent PyTorch releases may need `weights_only=False`
        to restore a full object; confirm against the installed version.
        """
        return torch.load(file)


# A sampler allows you to define how to select items from your Dataset. Torch
# provides a number of default Sampler classes
class TokenSampler(Sampler):
    """Produce batches of indices holding up to `batch_size` tokens each."""

    def __init__(
        self, dataset, batch_size=200, size_fn=len, drop_last=False, shuffle=True
    ):
        """
        args: `dataset` an indexable dataset (supports `len()` and `dataset[i]`)
              `batch_size` the maximum number of tokens in a batch
              `size_fn` a callable that yields the number of tokens in a dataset item
              `drop_last` if True, the final (possibly smaller) batch is not yielded
              `shuffle` if True, shuffle between every iteration
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.size_fn = size_fn
        self._len = None
        self.drop_last = drop_last
        # Honour the caller's choice (this used to be hard-coded to True).
        self.shuffle = shuffle

    def __iter__(self):
        """Yield lists of dataset indices whose token counts fit the budget.

        An item larger than `batch_size` on its own still forms a batch of one.
        """
        indices = list(range(len(self.dataset)))
        if self.shuffle:
            random.shuffle(indices)
        selected, numel = [], 0
        for index in indices:
            size = self.size_fn(self.dataset[index])
            # Close the running batch before it would exceed the token budget.
            if selected and numel + size > self.batch_size:
                yield selected
                selected, numel = [], 0
            numel += size
            selected.append(index)
        if selected and not self.drop_last:
            yield selected

    def __len__(self):
        """Estimated number of batches: total token count // `batch_size`.

        This is an approximation (computed once and cached); the exact number
        of batches produced by `__iter__` depends on how items pack together.
        """
        if self._len is None:
            total_tokens = sum(self.size_fn(item_index) for item_index in self._items())
            self._len = total_tokens // self.batch_size
        return self._len

    def _items(self):
        """Iterate over every dataset item in index order."""
        return (self.dataset[i] for i in range(len(self.dataset)))


# DataLoaders give access to an iterator over the dataset, using a sampling
# strategy as defined through a Sampler.
def get_dataloader(dataset, batch_size=200, shuffle=True):
    """Build a DataLoader over `dataset`.

    args: `dataset` a dataset exposing `vocab` and the `has_*` flags
          `batch_size` the maximum number of gloss tokens per batch when the
                       dataset has glosses, otherwise the number of items per batch
          `shuffle` if True, shuffle between every iteration
    """
    # Snapshot what the dataset contains; the collate closure relies on these.
    gloss_present = dataset.has_gloss
    vecs_present = dataset.has_vecs
    contextual_keys = [
        key
        for key, present in (
            ("electra_tensor", dataset.has_electra),
            ("bertseg_tensor", dataset.has_bertseg),
            ("bertmsa_tensor", dataset.has_bertmsa),
        )
        if present
    ]
    pad_index = dataset.vocab[PAD]

    def do_collate(json_dicts):
        """Merge a list of item dicts into one dict of lists, padding the
        glosses and stacking every embedding field that is present."""
        batch = defaultdict(list)
        for jdict in json_dicts:
            for key, value in jdict.items():
                batch[key].append(value)
        if gloss_present:
            batch["gloss_tensor"] = pad_sequence(
                batch["gloss_tensor"], padding_value=pad_index, batch_first=False
            )
        static_keys = (
            [f"{arch}_tensor" for arch in SUPPORTED_ARCHS] if vecs_present else []
        )
        for key in static_keys + contextual_keys:
            batch[key] = torch.stack(batch[key])
        return dict(batch)

    if not gloss_present:
        # No gloss means no token counts to balance: fall back to the default
        # fixed-size batching.
        return DataLoader(
            dataset, collate_fn=do_collate, batch_size=batch_size, shuffle=shuffle
        )

    def do_size_item(item):
        """Number of gloss tokens in an item, used to fill batches by token count."""
        return item["gloss_tensor"].numel()

    # Keep the amount of gloss tokens roughly constant across batches.
    token_sampler = TokenSampler(
        dataset, batch_size=batch_size, size_fn=do_size_item, shuffle=shuffle
    )
    return DataLoader(dataset, collate_fn=do_collate, batch_sampler=token_sampler)

Models 

In [None]:
class AraT5RevDict(nn.Module):
    """Seq2seq AraT5 model with an extra linear head on the mean-pooled encoder states.

    NOTE(review): the backbone is built with `from_config`, which presumably
    creates the architecture without loading the pretrained weights -- confirm
    that training from scratch is intended.
    """

    def __init__(self, max_len) -> None:
        """`max_len` is the output dimension of the linear projection head."""
        super().__init__()
        config = AutoConfig.from_pretrained("UBC-NLP/AraT5v2-base-1024")
        self.base_model = AutoModelForSeq2SeqLM.from_config(config)
        hidden_size = self.base_model.config.hidden_size
        self.linear = nn.Linear(hidden_size, max_len)

    def forward(self, input_ids, attention_mask, labels):
        """Return `(seq2seq_loss, embedding)` for one batch.

        The embedding is the attention-mask-weighted mean of the encoder's last
        hidden states, passed through the linear head.
        """
        seq2seq_out: Seq2SeqLMOutput = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
        )

        # Masked mean over the sequence axis: padded positions contribute zero.
        mask = attention_mask.unsqueeze(2)
        token_sum = (seq2seq_out.encoder_last_hidden_state * mask).sum(dim=1)
        token_count = attention_mask.sum(dim=1).unsqueeze(1)
        pooled_emb = token_sum / token_count

        return seq2seq_out.loss, self.linear(pooled_emb)

    def save(self, file):
        """Pickle the whole module to `file`."""
        torch.save(self, file)
        print("\n--\nsave1\n--\n")

    @staticmethod
    def load(file):
        """Restore a module written by `save` (unpickles -- trusted files only)."""
        return torch.load(file)

class ARBERTRevDict(nn.Module):
    """Encoder-only reverse-dictionary model: pooled encoder output -> linear head."""

    def __init__(self, args) -> None:
        """Build the backbone described by `args`.

        Reads `args.resume_train`, `args.from_pretrained`, `args.model_name`
        and `args.max_len` (the output size of the linear head).

        Raises:
            NotImplementedError: if `args.resume_train` is set; resuming is not
                supported, so fail before any checkpoint is loaded.
        """
        super().__init__()
        if args.resume_train:
            raise NotImplementedError(
                "Resuming training from a checkpoint is not supported yet."
            )
        if args.from_pretrained:
            self.base_model = AutoModel.from_pretrained(args.model_name)
        else:
            # Architecture only, without the pretrained weights.
            model_config = AutoConfig.from_pretrained(args.model_name)
            self.base_model = AutoModel.from_config(model_config)

        self.linear = nn.Linear(self.base_model.config.hidden_size, args.max_len)

    def forward(self, input_ids, token_type_ids , attention_mask):
        """Project the backbone's pooled output to the target embedding space.

        NOTE(review): `token_type_ids` is accepted but not forwarded to the
        backbone -- confirm this is intentional.
        """
        feats = self.base_model(input_ids=input_ids, attention_mask=attention_mask).pooler_output
        embedding = self.linear(feats)
        return embedding

    def save(self, file):
        """Write the backbone to `file` via `save_pretrained`.

        NOTE(review): only `base_model` is written; the `linear` head is not
        part of the checkpoint -- confirm this is intended.
        """
        self.base_model.save_pretrained(file,from_pt=True)
        print("\n--\nsave_pretrained\n--\n")
        # torch.save(self, file)

    @staticmethod
    def load(file):
        """Load a backbone saved by `save`; returns the AutoModel, not an ARBERTRevDict."""
        return AutoModel.from_pretrained(file)

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding followed by dropout (from PyTorch).

    The table is stored as the buffer `pe` with shape (max_len, 1, d_model) and
    is added along the first axis of the input.
    """

    def __init__(self, d_model, dropout=0.1, max_len=4096):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        even_dims = torch.arange(0, d_model, 2).float()
        inv_freq = torch.exp(even_dims * (-math.log(10000.0) / d_model))
        angles = positions * inv_freq

        # Even feature columns carry the sine, odd columns the cosine.
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        self.register_buffer("pe", table.unsqueeze(1))

    def forward(self, x):
        """Add the encodings of the first `x.size(0)` positions, then apply dropout."""
        seq_len = x.size(0)
        return self.dropout(x + self.pe[:seq_len])



class RevdictModel(nn.Module):
    """A transformer encoder that maps a gloss to a single embedding vector."""

    def __init__(
        self, vocab, d_model=256, n_head=4, n_layers=4, dropout=0.3, maxlen=512
    ):
        """
        args: `vocab` mapping from tokens to indices (must contain PAD and EOS)
              `d_model` embedding / hidden size, also the output size
              `n_head`, `n_layers` shape of the transformer encoder
              `dropout` dropout probability used throughout
              `maxlen` number of positions in the positional-encoding table
        """
        super().__init__()
        self.d_model = d_model
        self.padding_idx = vocab[PAD]
        self.eos_idx = vocab[EOS]
        self.maxlen = maxlen

        self.embedding = nn.Embedding(len(vocab), d_model, padding_idx=self.padding_idx)
        self.positional_encoding = PositionalEncoding(
            d_model, dropout=dropout, max_len=maxlen
        )
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=n_head,
                dropout=dropout,
                dim_feedforward=d_model * 2,
            ),
            num_layers=n_layers,
        )
        self.dropout = nn.Dropout(p=dropout)
        self.e_proj = nn.Linear(d_model, d_model)
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise matrices, zero the biases, set remaining 1-D params to one."""
        for name, param in self.named_parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)
                continue
            if "bias" in name:
                nn.init.zeros_(param)
                continue
            # gain parameters of the layer norm
            nn.init.ones_(param)

    def forward(self, gloss_tensor):
        """Encode `gloss_tensor` and return one vector per sequence.

        The input is assumed to be laid out as [Seq x Batch] token indices
        (TODO confirm against the caller). Non-padding encoder outputs are
        summed over the sequence axis, passed through ReLU and projected.
        """
        pad_mask = gloss_tensor == self.padding_idx
        encoder_in = self.positional_encoding(self.embedding(gloss_tensor))
        encoded = self.transformer_encoder(
            encoder_in, src_key_padding_mask=pad_mask.t()
        )
        encoded = self.dropout(encoded)
        # Zero out padded positions so they do not contribute to the sum.
        pooled = encoded.masked_fill(pad_mask.unsqueeze(-1), 0).sum(dim=0)
        return self.e_proj(F.relu(pooled))

    @staticmethod
    def load(file):
        """Restore a model written by `save` (unpickles -- trusted files only)."""
        return torch.load(file)

    def save(self, file):
        """Pickle the whole module to `file`."""
        torch.save(self, file)
        print("\n--\nsave2\n--\n")

Model usage

In [None]:
def _revdict_step(model, tokenizer, loss_fn, args, word, gloss, target_columns):
    """Run one batch through the model; return `(loss, pred_embs, target_embs)`."""
    word_tokens = tokenizer(word, padding=True, return_tensors='pt').to(args.device)
    gloss_tokens = tokenizer(gloss, padding=True, return_tensors='pt').to(args.device)

    # The loader yields one tensor per embedding dimension; stacking on dim=1
    # produces one row per sample.
    target_embs = torch.stack(target_columns, dim=1).to(args.device).float()

    # The gloss is the encoder input, the word's token ids are the seq2seq labels.
    ce_loss, pred_embs = model(
        gloss_tokens["input_ids"],
        gloss_tokens["attention_mask"],
        word_tokens["input_ids"],
    )
    loss = args.ce_loss_weight * ce_loss + loss_fn(pred_embs, target_embs)
    return loss, pred_embs, target_embs


def train(args):
    """Train an AraT5RevDict model and checkpoint it under `args.save_dir`.

    Attributes read from `args`: `train_file`, `dev_file` (optional),
    `batch_size`, `save_dir` (path-like supporting `/`), `summary_logdir`,
    `max_len` (output embedding size), `model_name` (tokenizer), `device`,
    `num_epochs`, `target_arch` ("electra", "bertseg" or "bertmsa"),
    `ce_loss_weight`, `best_model`, `last_state_model`.

    `rank_cosine` is expected to be defined elsewhere in the project.

    Raises:
        ValueError: if `args.train_file` is missing or `args.target_arch` is
            not one of the supported targets.
    """
    if args.train_file is None:
        raise ValueError("Missing dataset for training")
    target_archs = ("electra", "bertseg", "bertmsa")
    if args.target_arch not in target_archs:
        raise ValueError(
            f"Unsupported target_arch {args.target_arch!r}; expected one of {target_archs}"
        )

    # 1. get data and summary writer
    train_dataset = ARDDataset(args.train_file)
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=args.batch_size)
    valid_dataset, valid_dataloader = None, None
    if args.dev_file:
        valid_dataset = ARDDataset(args.dev_file)
        valid_dataloader = DataLoader(valid_dataset, batch_size=args.batch_size)
    summary_writer = SummaryWriter(args.save_dir / args.summary_logdir)
    train_step = itertools.count()  # to keep track of the training steps for logging

    # 2. construct model
    # AraT5RevDict takes the output dimension, not the whole argument namespace.
    model = AraT5RevDict(args.max_len).to(args.device)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model.train()

    # 3. declare optimizer & loss_fn
    ## Hyperparams
    EPOCHS, LEARNING_RATE, BETA1, BETA2, WEIGHT_DECAY = args.num_epochs, 1.0e-4, 0.9, 0.999, 1.0e-6
    optimizer = optim.AdamW(
        model.parameters(),
        lr=LEARNING_RATE,
        betas=(BETA1, BETA2),
        weight_decay=WEIGHT_DECAY,
    )
    loss_fn = nn.MSELoss()

    # Start below any reachable value so the first evaluated epoch is always saved.
    best_cosine = float("-inf")

    # 4. train model
    for epoch in tqdm.trange(EPOCHS, desc="Epochs"):
        ## train loop
        pbar = tqdm.tqdm(
            desc=f"Train {epoch}", total=len(train_dataset), disable=None, leave=False
        )
        for ids, word, gloss, electra, bertseg, bertmsa in train_dataloader:
            optimizer.zero_grad()
            targets = {"electra": electra, "bertseg": bertseg, "bertmsa": bertmsa}
            loss, pred_embs, target_embs = _revdict_step(
                model, tokenizer, loss_fn, args, word, gloss, targets[args.target_arch]
            )
            loss.backward()

            # keep track of the train loss for this step
            next_step = next(train_step)
            summary_writer.add_scalar(
                "revdict-train/cos",
                F.cosine_similarity(pred_embs, target_embs).mean().item(),
                next_step,
            )
            # NOTE: this tag records the combined (weighted CE + MSE) loss.
            summary_writer.add_scalar("revdict-train/mse", loss.item(), next_step)
            optimizer.step()
            pbar.update(target_embs.size(0))
        pbar.close()

        ## eval loop
        if valid_dataloader is not None:
            n_dev = len(valid_dataset)
            model.eval()
            with torch.no_grad():
                sum_dev_loss, sum_cosine = 0.0, 0.0
                pred_embs_list, target_embs_list = [], []
                pbar = tqdm.tqdm(
                    desc=f"Eval {epoch}", total=n_dev, disable=None, leave=False
                )
                for ids, word, gloss, electra, bertseg, bertmsa in valid_dataloader:
                    targets = {"electra": electra, "bertseg": bertseg, "bertmsa": bertmsa}
                    loss, pred_embs, target_embs = _revdict_step(
                        model, tokenizer, loss_fn, args, word, gloss, targets[args.target_arch]
                    )
                    n_batch = target_embs.size(0)
                    # The loss is a batch mean; weight it by the batch size so
                    # that dividing by the dataset size gives a per-sample mean.
                    sum_dev_loss += loss.item() * n_batch
                    sum_cosine += F.cosine_similarity(pred_embs, target_embs).sum().item()

                    pred_embs_list.append(pred_embs.cpu())
                    target_embs_list.append(target_embs.cpu())
                    pbar.update(n_batch)
                pbar.close()

                sum_rnk = rank_cosine(
                    torch.cat(pred_embs_list, dim=0), torch.cat(target_embs_list, dim=0)
                )

            # Print through tqdm so the line survives the progress bars.
            tqdm.tqdm.write(
                f"Eval {epoch} cos: {sum_cosine / n_dev} mse: {sum_dev_loss / n_dev}"
                f" rnk: {sum_rnk / n_dev} sum_rnk: {sum_rnk} len of dev: {n_dev}"
            )

            if sum_cosine >= best_cosine:
                best_cosine = sum_cosine
                print(f"Saving Best Checkpoint at Epoch {epoch} best cosine {best_cosine} .")
                model.save(args.save_dir / args.best_model)

            # keep track of the average metrics on dev set for this epoch
            summary_writer.add_scalar("revdict-dev/cos", sum_cosine / n_dev, epoch)
            summary_writer.add_scalar("revdict-dev/mse", sum_dev_loss / n_dev, epoch)
            summary_writer.add_scalar("revdict-dev/rnk", sum_rnk / n_dev, epoch)
            model.train()

        # Rolling checkpoint, overwritten after every epoch.
        model.save(args.save_dir / "model_epoch.pt")

    # 5. save result
    model.save(args.save_dir / args.last_state_model)
    summary_writer.close()

In [None]:
import os
import re
import json
import random
from collections import defaultdict
from itertools import count

import string
import math

import pandas as pd
import numpy as np



from sklearn.model_selection import train_test_split


import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
import pyarabic.araby as arab

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, Sampler

from transformers import BertTokenizer, BertModel, AutoModel, AutoConfig, AutoModelForSeq2SeqLM
from transformers.modeling_outputs import Seq2SeqLMOutput


import tensorflow as tf
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Embedding, LSTM, Dense
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical

# preprocess_text functions

In [None]:
# De-duplicated Arabic stop words from NLTK (kept as a list).
stop_words = list(set(stopwords.words('arabic')))
# Alternation of Arabic diacritics plus the tatweel. The layout (one
# alternative per line, trailing `#` notes, free whitespace) is verbose-regex
# syntax, so it only works when compiled/used with re.VERBOSE.
a_dict = r"""
                             ّ    | # Tashdid
                             َ    | # Fatha
                             ً    | # Tanwin Fath
                             ُ    | # Damma
                             ٌ    | # Tanwin Damm
                             ِ    | # Kasra
                             ٍ    | # Tanwin Kasr
                             ْ    | # Sukun
                            ـ    | # Tatwil/Kashida
                     """
# Emoji code-point ranges, meant to be wrapped in `[...]` as a character class.
regex_pattern = (
    "\U0001F600-\U0001F64F"+  # emoticons {😀 , 😆} 
    "\U0001F300-\U0001F5FF"+  # symbols & pictographs {🌍 , 🌞}
    "\U0001F680-\U0001F6FF"+  # transport & map symbols {🚌 , 🚕 }
    "\U0001F1E0-\U0001F1FF"   # flags (iOS) { 🇺🇸 , 🇨🇦 }
) 
def preprocess_text(text):
    """Clean an Arabic sentence and return its remaining tokens joined by spaces.

    Steps: strip punctuation/special characters, strip diacritics and tatweel,
    strip emoji, tokenize, then drop stop words and one-character tokens.
    """
    # Remove special characters {& $ @} and punctuation {. , ? !}
    text = re.sub(r'[^\w\s]', '', text)

    # Remove Arabic diacritics and tatweel. `a_dict` is written in verbose-regex
    # layout (whitespace and `#` notes inside the pattern), so it has to be
    # applied with re.VERBOSE; without the flag that filler is matched
    # literally and the pattern effectively never matches.
    text = re.sub(a_dict, '', text, flags=re.VERBOSE)

    # Remove emoji characters
    text = re.sub(f"[{regex_pattern}]", '', text)

    # Tokenize the sentence, then filter. The stop words are put in a set once
    # per call so each membership test is O(1) instead of a list scan.
    tokens = word_tokenize(text)
    stop_set = set(stop_words)
    kept = [word for word in tokens if len(word) > 1 and word not in stop_set]

    return ' '.join(kept)

In [None]:
# Build the AraT5-based model with a 256-dimensional projection head on CPU,
# and the tokenizer of the same checkpoint name.
model = AraT5RevDict(256).to("cpu")
tokenizer = AutoTokenizer.from_pretrained("UBC-NLP/AraT5v2-base-1024")

In [None]:
# Load dev.json with the non-test wrapper, so samples include the embeddings.
dataset = ARDDataset("dev.json")

In [None]:
# First sample as a tuple: (id, word, gloss, electra, bertseg, bertmsa).
data = dataset[0]
# data

In [None]:
# Unpack the fields used below; positions follow ARDDataset.__getitem__.
_id = data[0]
word = data[1]
gloss = data[2]
electra = data[3]  # target embedding (a list of floats in the JSON)

In [1]:
# Inspect the sample; a notebook cell only echoes its last expression, so just
# `electra` is displayed.
word
gloss
electra




[0.2265725881,
 -0.3225077391,
 0.5538389087,
 0.5800347328,
 -0.0342517458,
 -0.0777392015,
 0.2868331671,
 0.689817667,
 0.5971964598,
 0.4828594625,
 0.0246522464,
 0.2119424343,
 0.0674103796,
 0.3009205461,
 -0.8413527608,
 -0.0168413986,
 0.2170748115,
 0.5729941726,
 -0.1633950323,
 -0.3820695877,
 0.5487265587,
 -0.7529748678,
 0.681463778,
 0.6414878368,
 0.4434165657,
 -0.6303778291,
 0.4416134059,
 0.0699960962,
 1.3176093102,
 -0.3285644054,
 0.2450650483,
 -0.0553470254,
 0.49709481,
 -0.4185130298,
 0.5858063698,
 -0.0439687595,
 1.3485478163,
 -0.9497126937,
 0.1511357725,
 0.2152087986,
 -0.4863995016,
 0.3027804792,
 -0.0959090516,
 -0.0409083292,
 -0.8452766538,
 -0.1000854522,
 -0.4719405472,
 -0.0783871412,
 0.788713932,
 -0.1880373061,
 0.2888511419,
 1.185285449,
 0.5520726442,
 -0.8278334141,
 0.6819879413,
 -0.2602242529,
 0.1381222904,
 0.1392377019,
 0.0435579084,
 -0.4877517521,
 -0.0483157523,
 -0.2761443555,
 0.4849563539,
 -0.4398656189,
 -0.0801035389,
 0

In [None]:
# word_tokens = tokenizer(word, padding=True, return_tensors='pt').to("cpu")
# gloss_tokens = tokenizer(gloss, padding=True, return_tensors='pt').to("cpu")

In [2]:
# Tokenize a single example: the word is used as-is, the gloss is cleaned with
# preprocess_text first. Both come back as PyTorch tensors on CPU.
word_tokens = tokenizer(word, padding=True, return_tensors='pt').to("cpu")
gloss_tokens = tokenizer(preprocess_text(gloss), padding=True, return_tensors='pt').to("cpu")

print(word_tokens)
print(gloss_tokens)

{'input_ids': tensor([[46269,  4412, 30597,     1]]), 'attention_mask': tensor([[1, 1, 1, 1]])}
{'input_ids': tensor([[20381,   219, 29232,  2437,     1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}





{'input_ids': tensor([[46269,  4412, 30597,     1]]), 'attention_mask': tensor([[1, 1, 1, 1]])}

In [None]:
# Single forward pass as a smoke test: the gloss ids/mask feed the encoder and
# the word's token ids are passed as the seq2seq labels.
# target_embs = torch.stack(electra, dim=1).to("cpu")
# target_embs = target_embs.float()
ce_loss, pred_embs = model(
    gloss_tokens["input_ids"],
    gloss_tokens["attention_mask"],
    word_tokens["input_ids"],
)

In [4]:
# Display the predicted embedding; swap in the commented name to show the
# cross-entropy loss instead.
pred_embs
# ce_loss




tensor(135.2611, grad_fn=<NllLossBackward0>)

In [None]:
# Notebook-scope training loop on CPU with electra vectors as the target.
# NOTE(review): `train_dataset` and `train_dataloader` are not created in this
# cell; they must already exist in the notebook session (e.g. an ARDDataset
# wrapped in a DataLoader) -- confirm which file they should come from.
model = AraT5RevDict(256).to("cpu")
tokenizer = AutoTokenizer.from_pretrained("UBC-NLP/AraT5v2-base-1024")

model.train()
# 3. declare optimizer & loss_fn
## Hyperparams
EPOCHS, LEARNING_RATE, BETA1, BETA2, WEIGHT_DECAY = 5, 1.0e-4, 0.9, 0.999, 1.0e-6
optimizer = optim.AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    betas=(BETA1, BETA2),
    weight_decay=WEIGHT_DECAY,
)
loss_fn = nn.MSELoss()

# No `args` namespace exists at notebook scope; the target is fixed to electra,
# matching the hard-coded target used in the loop below.
vec_tensor_key = "electra_tensor"

best_cosine = 0

# 4. train model
for epoch in tqdm.trange(EPOCHS, desc="Epochs"):
    ## train loop
    pbar = tqdm.tqdm(
        desc=f"Train {epoch}", total=len(train_dataset), disable=None, leave=False
    )
    for ids, word, gloss, electra, bertseg, bertmsa in train_dataloader:
        optimizer.zero_grad()

        word_tokens = tokenizer(word, padding=True, return_tensors='pt').to("cpu")
        gloss_tokens = tokenizer(gloss, padding=True, return_tensors='pt').to("cpu")

        # The loader yields one tensor per embedding dimension; stacking on
        # dim=1 gives one row per sample.
        target_embs = torch.stack(electra, dim=1).to("cpu")
        target_embs = target_embs.float()
        ce_loss, pred_embs = model(
            gloss_tokens["input_ids"],
            gloss_tokens["attention_mask"],
            word_tokens["input_ids"],
        )

        mse_loss = loss_fn(pred_embs, target_embs)
        loss = 1 * ce_loss + mse_loss
        loss.backward()
        # Apply the gradients; without this step the weights never change.
        optimizer.step()
        pbar.update(target_embs.size(0))

    pbar.close()

    ## eval loop and per-epoch checkpointing are disabled in this cell
    # if args.dev_file:
        # model.save(args.save_dir / "model_epoch.pt")

# 5. save result
# model.save(args.save_dir / args.last_state_model)

eval

In [None]:
# model.eval()
# with torch.no_grad():
#     sum_dev_loss, sum_cosine, sum_rnk = 0.0, 0.0, 0.0
#     pbar = tqdm.tqdm(
#         desc=f"Eval {epoch}",
#         total=len(valid_dataset),
#         disable=None,
#         leave=False,
#     )
#     pred_embs_list, target_embs_list = [], []
#     for ids, word, gloss, electra, bertseg, bertmsa in valid_dataloader:
#         # word_tokens = tokenizer(word, padding=True, return_tensors='pt').to(args.device)
#         # gloss_tokens = tokenizer(gloss, max_length=512, padding=True, truncation=True, return_tensors='pt').to(args.device)

#         word_tokens = tokenizer(word, padding=True, return_tensors='pt').to(args.device)
#         gloss_tokens = tokenizer(gloss, padding=True, return_tensors='pt').to(args.device)

#         if args.target_arch == "electra":
#             target_embs = torch.stack(electra, dim=1).to(args.device)
#         elif args.target_arch == "bertseg":
#             target_embs = torch.stack(bertseg, dim=1).to(args.device)
#         elif args.target_arch == "bertmsa":
#             target_embs = torch.stack(bertmsa, dim=1).to(args.device)

#         target_embs = target_embs.float()

#         ce_loss, pred_embs = model(
#             gloss_tokens["input_ids"],
#             gloss_tokens["attention_mask"],
#             word_tokens["input_ids"],
#         )

#         mse_loss = loss_fn(pred_embs, target_embs)
#         loss = args.ce_loss_weight * ce_loss + mse_loss

#         # sum_dev_loss += (
#         #     F.mse_loss(pred_embs, target_embs, reduction="none").mean(1).sum().item()
#         # )
#         sum_dev_loss += loss.item()
#         sum_cosine += F.cosine_similarity(pred_embs, target_embs).sum().item()

#         # sum_rnk += rank_cosine(pred_embs, target_embs)

#         pred_embs_list.append(pred_embs.cpu())
#         target_embs_list.append(target_embs.cpu())

#         pbar.update(target_embs.size(0))

#     sum_rnk = rank_cosine(torch.cat(pred_embs_list, dim=0), torch.cat(target_embs_list, dim=0))

#     pbar = tqdm.tqdm(
#         desc=f"Eval {epoch} cos: "+str(sum_cosine / len(valid_dataset))+" mse: "+str( sum_dev_loss / len(valid_dataset) )+" rnk: "+str(sum_rnk/ len(valid_dataset))+ " sum_rnk: "+str(sum_rnk)+" len of dev: "+str(len(valid_dataset)) +"\n",
#         total=len(valid_dataset),
#         disable=None,
#         leave=False,
#     )

#     if sum_cosine >= best_cosine:
#         best_cosine = sum_cosine
#         print(f"Saving Best Checkpoint at Epoch {epoch} best cosine {best_cosine} .")
#         model.save(args.save_dir / args.best_model)


#     # keep track of the average loss on dev set for this epoch
#     summary_writer.add_scalar(
#         "revdict-dev/cos", sum_cosine / len(valid_dataset), epoch
#     )
#     summary_writer.add_scalar(
#         "revdict-dev/mse", sum_dev_loss / len(valid_dataset), epoch
#     )
#     summary_writer.add_scalar(
#         "revdict-dev/rnk", sum_rnk / len(valid_dataset), epoch
#     )
#     pbar.close()
#     model.train()