In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [1]:
import os, itertools, tqdm, codecs, random, pickle
from collections import Counter
import numpy as np
import torch
from torch import nn
import numpy as np
import torch.nn.functional as F
import torch.utils.data as D
#DRIVE_PATH = "/content/drive/MyDrive/NLP/HW3"
DRIVE_PATH = "/Users/sepehr/Desktop/Uni/Courses/NLP/HW3"

In [2]:
from importlib.machinery import SourceFileLoader
ModelsModule = SourceFileLoader("ModelsModule", DRIVE_PATH+'/Models.py').load_module()
DatasetsModule = SourceFileLoader("DatasetsModule", DRIVE_PATH+'/Datasets.py').load_module()
from ModelsModule import LSTM, SiameseLSTM, myword2vec
from DatasetsModule import MasnaviDataset, RhymeBatchSampler

# Utils

In [3]:
PERSIAN_EMBEDDINGS = "اأآبپتثجچحخدذرزژسشصضطظعغفقکگلمنوهیئ"

def do_rhyme_words(w1,w2):
    return w1.endswith(w2) or w2.endswith(w1) or w2 == w1 or (w1[-2:]==w2[-2:] and w1[-2:]!='ست')

def do_rhyme_mesras(m1, m2):
    return do_rhyme_words( m1[-1], m2[-1] )

def is_masnavi(curr_beyt, next_beyt):
    if not do_rhyme_mesras(*curr_beyt):
        return False
    if (next_beyt is not None) and not do_rhyme_mesras(*next_beyt):
        return False
    return True

def get_beyt_rhyme(mesra1, mesra2):
    for idx in range(1,min(len(mesra1),len(mesra2))):
        if do_rhyme_words(mesra1[-idx], mesra2[-idx]) and mesra1[-idx] != mesra2[-idx]:
            return tuple(sorted((mesra1[-idx], mesra2[-idx])))
    return None

# Dataset

**Skip Bellow Cells And Only Run Loader Cell**

In [5]:
mesras = {}
for filename in os.listdir(DRIVE_PATH+'/Persian_poems_corpus/normalized'):
    mesras[filename[:-9]] = list(filter(lambda m:len(m)>2, [x.strip().split() for x in codecs.open(DRIVE_PATH+f'/Persian_poems_corpus/normalized/{filename}','rU','utf-8').readlines()]))
with open(DRIVE_PATH+'/datasets/mesras.pickle', 'wb') as f:
    pickle.dump(mesras, f)

In [6]:
masnavis = []
for poet,p_mesras in mesras.items():
    for idx in range(0,len(p_mesras),2):
        if idx+1 >= len(p_mesras):
            break
        curr_beyt = (p_mesras[idx], p_mesras[idx+1])
        next_beyt = None if idx+3 >= len(p_mesras) else (p_mesras[idx+2], p_mesras[idx+3])
        if is_masnavi(curr_beyt, next_beyt):
            masnavis.append(curr_beyt)
with open(DRIVE_PATH+'/datasets/masnavis.pickle', 'wb') as f:
    pickle.dump(masnavis, f)

In [7]:
qazals = []
for poet,p_mesras in mesras.items():
    idx = 0
    while idx < len(p_mesras):
        if idx+1 >= len(p_mesras):
            break
        curr_beyt = (p_mesras[idx], p_mesras[idx+1])
        idx += 2
        if do_rhyme_mesras(*curr_beyt):
            qazal = [curr_beyt]
            next_beyt = None if idx+1 >= len(p_mesras) else (p_mesras[idx], p_mesras[idx+1])
            while (next_beyt is not None) and do_rhyme_mesras(next_beyt[1], curr_beyt[1]):
                qazal.append(next_beyt)
                idx+=2
                next_beyt = None if idx+1 >= len(p_mesras) else (p_mesras[idx], p_mesras[idx+1])
            if len(qazal) > 1:
                qazals.append(qazal)
with open(DRIVE_PATH+'/datasets/qazals.pickle', 'wb') as f:
    pickle.dump(qazals, f)

In [11]:
rhymes = []
for beyt in masnavis:
    rhymes.append( get_beyt_rhyme(*beyt) )
for qazal in qazals:
    rhymes.append(get_beyt_rhyme(*qazal[0]))
   
    #option 1
    for idx in range(1,len(qazal)):
        rhymes.append( get_beyt_rhyme(qazal[0][1], qazal[idx][1]) )
    
    
    # #option2
    # for idx in range(0,len(qazal)-1):
    #     rhymes.append( get_beyt_rhyme(qazal[idx][1], qazal[idx+1][1]) )
    
    # option 3
#     for idx1 in range(0,len(qazal)):
#         for idx2 in range(idx1+1,len(qazal)):
#             rhymes.append( get_beyt_rhyme(qazal[idx1][1], qazal[idx2][1]) )
rhymes = list(set([r for r in rhymes if (r is not None)]))
with open(DRIVE_PATH+'/datasets/rhymes.pickle', 'wb') as f:
    pickle.dump(rhymes, f)

## Datasets Loader

In [4]:
with open(DRIVE_PATH+'/datasets/mesras.pickle', 'rb') as f:
    mesras = pickle.load(f)
with open(DRIVE_PATH+'/datasets/masnavis.pickle', 'rb') as f:
    masnavis = pickle.load(f)
with open(DRIVE_PATH+'/datasets/qazals.pickle', 'rb') as f:
    qazals = pickle.load(f)
with open(DRIVE_PATH+'/datasets/rhymes.pickle', 'rb') as f:
    rhymes = pickle.load(f)

### Pretrained Word2Vec

In [5]:
#!pip3 install gensim
from gensim.models import Word2Vec
from gensim.models import KeyedVectors

class LiteratureWord2Vec(object):
    def __init__(self):
        super(LiteratureWord2Vec, self).__init__()
        self.corpus = KeyedVectors.load_word2vec_format(DRIVE_PATH+'/datasets/farsi_literature_word2vec_model.txt', binary=False)
        self.emb_dim = 100
        
    def add_new_word(self, w):
        emb = self.corpus['ا'].copy()
        while self.corpus.most_similar([emb], topn=1)[0][1]>0.5:
            emb = np.random.normal(0,1,self.emb_dim)
        self.corpus.add_vector(w, emb)
        self.corpus.fill_norms(force=True)

    def __call__(self, words, pad_to=None):
        if isinstance(words, str):
            words = [words]
        embeddings = []
        for word in words:
            emb = None
            if word in self.corpus:
                emb = self.corpus[word]
            else:
                for i in range(len(word)):
                    if word[:i] in self.corpus and word[i:] in self.corpus:
                        emb = self.corpus[word[:i]]+self.corpus[word[i:]]
            if emb is None:
                unk_emb = self.corpus['ا'].copy()
                unk_emb[unk_emb!=0]=0 # UNK word embedding - #TODO: its bad for cosine similarity
                emb = unk_emb
            embeddings.append( torch.Tensor(emb) )
            
        return torch.vstack(embeddings) 

### Corpus and Word Embeddings Model

In [6]:
all_mesras = list(itertools.chain(*mesras.values()))
unigrams = list(itertools.chain(*all_mesras))
unigrams = list(set([w for w in unigrams]))

literatureWord2Vec = LiteratureWord2Vec()

special_words = ["__PAD__", "__BOM__", "__EOM__"]
for w in special_words:
    unigrams.append(w)
    literatureWord2Vec.add_new_word(w)



# N-Gram Language Model

In [42]:
#@title Default title text
# Modified version of 
# https://github.com/joshualoehr/ngram-language-model/blob/master/language_model.py
from itertools import product
import math
import nltk


class LanguageModel(object):
    """An n-gram language model trained on a given corpus.
    
    For a given n and given training corpus, constructs an n-gram language
    model for the corpus by:
    1. preprocessing the corpus (adding SOS/EOS/UNK tokens)
    2. calculating (smoothed) probabilities for each n-gram
    Also contains methods for calculating the perplexity of the model
    against another corpus, and for generating sentences.
    Args:
        train_data (list of str): list of sentences comprising the training corpus.
        n (int): the order of language model to build (i.e. 1 for unigram, 2 for bigram, etc.).
        laplace (int): lambda multiplier to use for laplace smoothing (default 1 for add-1 smoothing).
    """

    SOS = "__BOM__"
    EOS = "__EOM__"
    UNK = "<UNK>"
    
    def __init__(self, train_data, n, laplace=1):
        self.n = n
        self.vocab = dict()
        self.laplace = laplace
        self.tokens = self.preprocess(train_data, n)
        self.vocab  = nltk.FreqDist(self.tokens)
        self.model  = self._create_model()
        self.masks  = list(reversed(list(product((0,1), repeat=n))))

    def _smooth(self):
        """Apply Laplace smoothing to n-gram frequency distribution.
        
        Here, n_grams refers to the n-grams of the tokens in the training corpus,
        while m_grams refers to the first (n-1) tokens of each n-gram.
        Returns:
            dict: Mapping of each n-gram (tuple of str) to its Laplace-smoothed 
            probability (float).
        """
        vocab_size = len(self.vocab)

        n_grams = nltk.ngrams(self.tokens, self.n)
        n_vocab = nltk.FreqDist(n_grams)

        m_grams = nltk.ngrams(self.tokens, self.n-1)
        m_vocab = nltk.FreqDist(m_grams)

        def smoothed_count(n_gram, n_count):
            m_gram = n_gram[:-1]
            m_count = m_vocab[m_gram]
            return (n_count + self.laplace) / (m_count + self.laplace * vocab_size)

        return { n_gram: smoothed_count(n_gram, count) for n_gram, count in n_vocab.items() }

    def _create_model(self):
        """Create a probability distribution for the vocabulary of the training corpus.
        
        If building a unigram model, the probabilities are simple relative frequencies
        of each token with the entire corpus.
        Otherwise, the probabilities are Laplace-smoothed relative frequencies.
        Returns:
            A dict mapping each n-gram (tuple of str) to its probability (float).
        """
        if self.n == 1:
            num_tokens = len(self.tokens)
            return { (unigram,): count / num_tokens for unigram, count in self.vocab.items() }
        else:
            return self._smooth()

    def _convert_oov(self, ngram):
        """Convert, if necessary, a given n-gram to one which is known by the model.
        Starting with the unmodified ngram, check each possible permutation of the n-gram
        with each index of the n-gram containing either the original token or <UNK>. Stop
        when the model contains an entry for that permutation.
        This is achieved by creating a 'bitmask' for the n-gram tuple, and swapping out
        each flagged token for <UNK>. Thus, in the worst case, this function checks 2^n
        possible n-grams before returning.
        Returns:
            The n-gram with <UNK> tokens in certain positions such that the model
            contains an entry for it.
        """
        mask = lambda ngram, bitmask: tuple((token if flag == 1 else "<UNK>" for token,flag in zip(ngram, bitmask)))

        ngram = (ngram,) if type(ngram) is str else ngram
        for possible_known in [mask(ngram, bitmask) for bitmask in self.masks]:
            if possible_known in self.model:
                return possible_known

    def perplexity(self, test_data):
        """Calculate the perplexity of the model against a given test corpus.
        
        Args:
            test_data (list of str): sentences comprising the training corpus.
        Returns:
            The perplexity of the model as a float.
        
        """
        test_tokens = self.preprocess(test_data, self.n)
        test_ngrams = nltk.ngrams(test_tokens, self.n)
        N = len(test_tokens)

        known_ngrams  = [self._convert_oov(ngram) for ngram in test_ngrams]
        probabilities = [self.model[ngram] for ngram in known_ngrams]
        
        for x,y in zip(known_ngrams, probabilities):
            print(x,y)
        
        return math.exp((-1/N) * sum(map(math.log, probabilities)))

    def _best_candidate(self, prev, without=[]):
        
        blacklist  = [LanguageModel.UNK] + without

        if len(prev) < self.n:
            prev = [LanguageModel.SOS]*(self.n-1)

        candidates = list(((ngram[-1],prob) for ngram,prob in self.model.items() if ngram[:-1]==tuple(prev)))

        probs = [y for x,y in candidates]
        probs = probs/np.sum(probs)
        words = [x for x,y in candidates]

        idx = np.random.choice(len(words), 1, replace=False, p=probs)[0]
        
        while words[idx] in blacklist:
            idx = np.random.choice(len(words), 1, replace=False, p=probs)[0]
        
        return (words[idx], probs[idx])
         
    def generate_sentence(self, input, min_len=12, max_len=24):
        #sent, prob = ([LanguageModel.SOS] * (max(1, self.n-1)), 1)
        sent, prob, start = input.copy(), 1, True
        while sent[-1] != LanguageModel.EOS or start:
            start = False
            prev = () if self.n == 1 else tuple(sent[-(self.n-1):])
            blacklist = sent + ([LanguageModel.EOS,LanguageModel.SOS] if len(sent) < min_len else [])
            next_token, next_prob = self._best_candidate(prev, without=blacklist)
            sent.append(next_token)
            prob *= next_prob

            if len(sent) >= max_len:
                sent.append(LanguageModel.EOS)

        #return (' '.join(sent[(self.n-1):-1]), -1/math.log(prob))
        #return (' '.join(sent), -1/math.log(prob))
        return ' '.join(sent)
    
    

    def add_sentence_tokens(self, sentences, n):
        """Wrap each sentence in SOS and EOS tokens.
        For n >= 2, n-1 SOS tokens are added, otherwise only one is added.
        Args:
            sentences (list of str): the sentences to wrap.
            n (int): order of the n-gram model which will use these sentences.
        Returns:
            List of sentences with SOS and EOS tokens wrapped around them.
        """
        sos = ' '.join([LanguageModel.SOS] * (n-1)) if n > 1 else LanguageModel.SOS
        return ['{} {} {}'.format(sos, s, LanguageModel.EOS) for s in sentences]

    def replace_singletons(self, tokens):
        """Replace tokens which appear only once in the corpus with <UNK>.

        Args:
            tokens (list of str): the tokens comprising the corpus.
        Returns:
            The same list of tokens with each singleton replaced by <UNK>.

        """
        if len(self.vocab) == 0:
            self.vocab = nltk.FreqDist(tokens)
        return [token if self.vocab[token] > 1 else LanguageModel.UNK for token in tokens]

    def preprocess(self, sentences, n):
        """Add SOS/EOS/UNK tokens to given sentences and tokenize.
        Args:
            sentences (list of str): the sentences to preprocess.
            n (int): order of the n-gram model which will use these sentences.
        Returns:
            The preprocessed sentences, tokenized by words.
        """
        sentences = self.add_sentence_tokens(sentences, n)
        tokens = ' '.join(sentences).split()
        tokens = self.replace_singletons(tokens)
        return tokens    

# Encoder - Decoder

In [7]:
class Encoder(nn.Module):
    def __init__(self, embeddings, hidden_dims, num_layers, dropout, device):
        super(Encoder, self).__init__()
        self.embeddings = embeddings
        self.device = device
        self.enc = LSTM(embeddings('ا').size(1), hidden_dims, num_layers, dropout, device)

    def forward(self, input):
        embeds = torch.stack([self.embeddings(s) for s in input]).to(self.device)
        N, L = embeds.size(0), embeds.size(1)
        hidden, cell = self.enc.init_hidden(N)
        output, hidden, cell = self.enc(embeds, N, hidden, cell)
        return output, hidden, cell
    
class Decoder(nn.Module):
    def __init__(self, corpus, embeddings, hidden_dims, output_dims, num_layers, dropout, device):
        super(Decoder, self).__init__()
        self.embeddings = embeddings
        self.corpus = corpus
        self.device = device
        self.dec = LSTM(embeddings('ا').size(1), hidden_dims, num_layers, dropout, device)
        self.output_net = nn.Sequential(
            nn.Linear(hidden_dims,output_dims),
#             nn.ReLU(),
#             nn.Linear(8192,output_dims),
#             nn.ReLU(),
#             nn.Linear(output_dims,output_dims),
            nn.Softmax(dim=1)
        ).to(device)

    def forward(self, target, hidden, cell, teacher_force_prob=1.0):
        target_embeds = torch.stack([self.embeddings(s) for s in target]).to(self.device)
        N, L = target_embeds.size(0), target_embeds.size(1)
        start_token = "__PAD__"
        dec_input = torch.stack([self.embeddings(start_token)]*N)
        dec_hidden, dec_cell = hidden, cell
        tf = False if (teacher_force_prob < 1.0 and random.random() < 1.0-teacher_force_prob) else True
        pred = []
        for l in range(L):
            dec_output, dec_hidden, dec_cell = self.dec(dec_input, N, dec_hidden, dec_cell)
            dec_pred = self.output_net(dec_output.squeeze(dim=1))
            pred.append(dec_pred)
            preds_embs =torch.stack([self.embeddings(self.corpus[idx]) for idx in dec_pred.argmax(dim=1)])
            dec_input = target_embeds[:,l,:].unsqueeze(dim=1) if tf else preds_embs # detach from history as input
        return torch.stack(pred)
        
class EncoderDecoder(nn.Module):
    def __init__(self, corpus, embeddings, hidden_dims, output_dims, num_layers, dropout, device):
        super(EncoderDecoder, self).__init__()
        self.device = device
        self.encoder = Encoder(embeddings, hidden_dims, num_layers, dropout, device)
        self.decoder = Decoder(corpus, embeddings, hidden_dims, output_dims, num_layers, dropout, device)
        self.loss_fn = nn.CrossEntropyLoss(reduction="mean")
        
    def forward(self, input, target, teacher_force_prob=1.0):
        _, enc_hidden, enc_cell = self.encoder(input)
        return self.decoder(target, enc_hidden, enc_cell, teacher_force_prob).transpose(0,1)
    
    def get_loss(self, preds, target_indices, mask):
        #loss = [self.loss_fn(preds[idx].unsqueeze(dim=0), target_indices[idx].unsqueeze(dim=0))*mask[idx] for idx in range(len(mask))]
        loss = self.loss_fn(preds[0:mask,:], target_indices[0:mask].to(self.device))
        return loss

In [11]:
def train_encdec_model(encdec, dataLoader, optim, epochs, device):
    encdec.train()
    for epoch in range(epochs):
        print(f"=============epoch:{epoch}=============")
        for batch_idx, (mesras, (targets, targets_indices, masks)) in enumerate(dataLoader):
            #optim.zero_grad()

            mesras = [m.split("#") for m in mesras]
            targets = [t.split("#") for t in targets]
            targets_indices = targets_indices
            preds = encdec(mesras, targets)
            loss = 0
            for idx, seq in enumerate(preds):
                loss += encdec.get_loss(seq, targets_indices[idx], masks[idx])
            loss.backward()
            if batch_idx%100 ==0:
                print(f"batch {batch_idx} loss: {loss}")
                torch.save(encdec.state_dict(), DRIVE_PATH+f'/encdec_model.pt')
            optim.step()
    
    return encdec

In [None]:
encdec_batch_size = 4
encdec_epochs = 5
encdec_lr = 0.01
encdec_device = 'cpu'

masnaviDataset = MasnaviDataset(masnavis, unigrams)
masnaviDataLoader = D.DataLoader(masnaviDataset, encdec_batch_size)
encdec = EncoderDecoder(corpus=unigrams, embeddings=literatureWord2Vec, hidden_dims=128, output_dims=len(unigrams), num_layers=1, dropout=0.0, device=encdec_device)
encdec_optim = torch.optim.Adam(encdec.parameters(), lr=encdec_lr)
encdec = train_encdec_model(encdec, masnaviDataLoader, encdec_optim, encdec_epochs, encdec_device)

batch 0 loss: 47.63416290283203
batch 1 loss: 47.63416290283203
batch 2 loss: 47.63420104980469
batch 3 loss: 47.634307861328125
batch 4 loss: 47.63398742675781
batch 5 loss: 47.63190460205078
batch 6 loss: 47.626617431640625
batch 7 loss: 47.60655975341797
batch 8 loss: 47.533416748046875
batch 9 loss: 47.368125915527344
batch 10 loss: 47.11322784423828
batch 11 loss: 46.976680755615234
batch 12 loss: 46.937801361083984
batch 13 loss: 46.87327194213867
batch 14 loss: 46.930030822753906
batch 15 loss: 46.95907211303711
batch 16 loss: 46.8367919921875
batch 17 loss: 46.84014892578125
batch 18 loss: 46.87990951538086
batch 19 loss: 46.8219108581543
batch 20 loss: 46.851654052734375
batch 21 loss: 46.91090393066406
batch 22 loss: 46.78603744506836
batch 23 loss: 46.89038848876953
batch 24 loss: 46.85102844238281
batch 25 loss: 46.712928771972656
batch 26 loss: 46.81257629394531
batch 27 loss: 46.647254943847656
batch 28 loss: 46.62928771972656
batch 29 loss: 46.626705169677734


## Rhyme Embedding

### Preparing Rhymes Dataset

In [7]:
NON_RHYME_DATA_POINT = 1 # TODO bayad 1 bashe vagaran eval ro npratio=0.5 bad mishe (ba 3 result khoob bood)

In [8]:
rhymes_dataset = [(w1,w2,1) for (w1,w2) in rhymes]
for w1,w2 in rhymes:
    for idx in range(NON_RHYME_DATA_POINT):
        x = None
        while True:
            x = random.sample(unigrams,1)[0]
            if not do_rhyme_words(w1,x):
                break
        rhymes_dataset.append( (w1,x,-1) )
        while True:
            x = random.sample(unigrams,1)[0]
            if not do_rhyme_words(w2,x):
                break
        rhymes_dataset.append( (x,w2,-1) )

In [9]:
random.sample(rhymes_dataset,5)

[('تعبیر', 'ر', 1),
 ('اولست', 'منازلم', -1),
 ('انداختی', 'برجاستی', 1),
 ('پار', 'یسار', 1),
 ('طور', 'تاهت', -1)]

In [10]:
seed = 42
train_length, test_length = (int)(0.95*len(rhymes_dataset)), len(rhymes_dataset)-(int)(0.95*len(rhymes_dataset))
rhymes_dataset_train, rhymes_dataset_test = torch.utils.data.random_split(rhymes_dataset, [train_length, test_length],torch.Generator().manual_seed(seed))

### Rhymes Model

In [11]:
def run_rhyme_model(rhyme_model, data_loader, batch_size, device, train=True):
    print(f"Running on {device}")
    accs, losses = [], []
    for idx, batch in enumerate(data_loader):
        rhyme_model.optimizer.zero_grad() # too Result khoobe nabood =)))
        
        r1 = torch.vstack([myword2vec(batch[0][b]) for b in range(batch_size)]).to(device)
        r2 = torch.vstack([myword2vec(batch[1][b]) for b in range(batch_size)]).to(device)
        y = torch.Tensor([batch[2][b] for b in range(batch_size)]).to(device)
        embedding1, embedding2 = rhyme_model(r1,r2,y)
        loss = rhyme_model.get_loss(embedding1.view(batch_size,-1), embedding2.view(batch_size,-1),y)
        losses.append(loss)

        y_pred = [1 if l<0.3 else -1 for l in loss]
        acc = sum([int(y[idx]==y_pred[idx]) for idx in range(len(y))])/len(y)
        accs.append(acc)

        if train:
            loss = loss.mean()
            loss.backward()
        
            rhyme_model.optimizer.step()
        
            if idx%1000 == 0:
                print(f"Train Loss: {loss}.         Train Acc: {acc}")
                torch.save(rhyme_model.state_dict(), DRIVE_PATH+f'/rhyme_models/rhyme_model{idx}.pt')
    return accs, losses

In [12]:
#!rm -rf $DRIVE_PATH/rhyme_models/*

#### Train

In [13]:
rhm_batch_size=256
rhm_train_iterations = 20000 # approx 5 epoch
rhm_device = 'cpu'

rhyme_model = SiameseLSTM(embedding_dim=128, hidden_dim=512, num_layers=1, dropout=0.0, learning_rate=0.01, device=rhm_device)
#rhyme_model.load_state_dict(torch.load(DRIVE_PATH+'/rhyme_models/rhyme_model94000.pt'))

rhymes_train_sampler = RhymeBatchSampler(rhymes_dataset_train, npratio=0.5, iterations=train_iterations, batch_size=rhm_batch_size) #npratio=1/(1+NON_RHYME_DATA_POINT)
rhymes_train_data_loader = torch.utils.data.DataLoader(rhymes_dataset_train, batch_sampler=rhymes_train_sampler)

train_accs, train_losses = run_rhyme_model(rhyme_model, rhymes_train_data_loader, rhm_batch_size, rhm_device, train=True)

Running on cpu
Train Loss: 0.5000045299530029.         Train Acc: 1.0


KeyboardInterrupt: 

#### Eval

In [None]:
rhyme_model = SiameseLSTM(embedding_dim=128, hidden_dim=512, num_layers=1, dropout=0.0, learning_rate=0.01, device=rhm_device)
rhymes_test_sampler = RhymeBatchSampler(rhymes_dataset_test, npratio=0.5, iterations=2000, batch_size=rhm_batch_size)
rhymes_test_data_loader = torch.utils.data.DataLoader(rhymes_dataset_test, batch_sampler=rhymes_test_sampler)

rhyme_model.load_state_dict(torch.load(DRIVE_PATH+'/rhyme_models/rhyme_model9000.pt', map_location=torch.device(rhm_device)))

rhyme_model.eval()
eval_accs, eval_loss = run_rhyme_model(rhyme_model, rhymes_test_data_loader, rhm_batch_size, rhm_device, train=False)
print(f"Eval Accuracy: {mean(eval_accs)}")

In [None]:
rhyme_model.predict("سازش","بارش")