In [1]:
!rm -rf wiki_libs
!git clone https://github.com/tsonic/wiki_libs.git

Cloning into 'wiki_libs'...
remote: Enumerating objects: 87, done.[K
remote: Counting objects: 100% (87/87), done.[K
remote: Compressing objects: 100% (60/60), done.[K
remote: Total 87 (delta 44), reused 60 (delta 25), pack-reused 0[K
Unpacking objects: 100% (87/87), done.


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import json
from collections import defaultdict
import pandas as pd
from zipfile import ZipFile
import pickle
import gc
from gensim.models.phrases import Phrases, Phraser
import itertools
import importlib
import wiki_libs
from wiki_libs.preprocessing import *
from functools import partial
import torch.optim as optim
import traceback
importlib.reload(wiki_libs.preprocessing)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
# cudnn.benchmark = True # CUDA for PyTorch. Faster runtime if input size constant in iterations.  
!pip install ipdb > /dev/null
!pip install line_profiler > /dev/null
%load_ext line_profiler
import ipdb
import pdb

np.random.seed(12345)

In [3]:
from google.colab import drive
drive.flush_and_unmount()
drive.mount('/content/gdrive',force_remount=True)

Drive not mounted, so nothing to flush and unmount.
Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/gdrive


Configuration Class

In [None]:
%%time
df_cp = read_category_links()

In [None]:
%%time
#load ngram model

ngram_model = load_ngram_model(NGRAM_MODEL_PATH_PREFIX + "title_category_ngram_model.pickle")

In [None]:
print(df_cp.columns)
print(df_cp.shape)

In [None]:
df_cp['page_title'].isnull().mean()

In [None]:
%%time
# tranform title, return list of list of ngrams
title_transformed, category_transformed = get_transformed_title_category(ngram_model)

In [None]:
vocab_title = generate_vocab(title_transformed)

In [None]:
vocab_category = generate_vocab(category_transformed)

In [None]:
vocab_title_category = generate_vocab(title_transformed + category_transformed)

In [None]:
# with open('gdrive/My Drive/Projects with Wei/wiki_data/ngram_model/vocab.json','w') as f:
#   json.dump({
#       'vocab_title':vocab_title,
#       'vocab_category':vocab_category,
#       'vocab_title_category':vocab_title_category,
#       }), f)

In [None]:
with open('gdrive/My Drive/Projects with Wei/wiki_data/ngram_model/vocab.json','r') as f:
    vocab_json = json.load(f)
    vocab_title = vocab_json['vocab_title']
    vocab_category = vocab_json['vocab_category']
    vocab_title_category = 'vocab_title_category'
    word2ind_title_category = defaultdict(lambda:-1, {w:i for i, w in enumerate(vocab_title_category)})
    word2ind_title = defaultdict(lambda:-1, {w:i for i, w in enumerate(vocab_title)})
    word2ind_category = defaultdict(lambda:-1, {w:i for i, w in enumerate(vocab_category)})

In [None]:
len(vocab_title)


In [None]:
def ngram_to_idx(l, word2ind):
  return np.array([word2ind[ng] for ng in l])

In [None]:
%%time
df_page_2_bow = (
    df_cp[['page_id','page_title']]
    .drop_duplicates()
    .fillna({'page_title':''})
    .assign(processed_title = lambda df: df['page_title'].apply(process_title))
    .assign(processed_title_ngram = lambda df: 
            transform_ngram(df['processed_title'].tolist(), ngram_model))
    .assign(processed_title_ngram_idx = lambda df: 
            df['processed_title_ngram'].apply(ngram_to_idx, word2ind=word2ind_title))
)

In [None]:
%%time
df_category_2_bow = (
    df_cp[['page_category']]
    .drop_duplicates()
    .fillna({'page_category':''})
    .assign(processed_category = lambda df: df['page_category']
                                      .apply(process_title))
    .assign(processed_category_ngram = lambda df: transform_ngram(df['processed_category'].tolist(), ngram_model))
    .assign(processed_category_ngram_idx = 
                                     lambda df: df['processed_category_ngram'].apply(ngram_to_idx, word2ind=word2ind_category))
)

In [None]:
%%time
df_lp = next(read_link_pairs_chunks(n_chunk=10))
df_lp.shape

In [None]:
class Config():
    #TODO: fill in the dirs
    training_dir = \
    testing_dir = \

    #TODO: complete hyper-par defs
    EMD_DIM1      = 2 ** 3 #left tower 
    EMD_DIM2      = 2 ** 3 #right tower (= left for L2 cals)
    MAX_EPOCHS    = 100
    LEARNING_RATE = 1 ** (-5) # grid \in [-6:-3]
    NUM_EPOCHS     = 10

    # data parameters. to be used in the DataLoader caclass
    data_params = {'BATCH_SIZE': 2 ** 6, # grid \in [3:8]
                   'shuffle': True,
                   'num_workers': 6}


Helper_functions

In [None]:
def show_plot(iteration, loss):
    plt.plot(iteration, loss)
    plt.show()

**Import** data

Reference: https://stanford.edu/~shervine/blog/pytorch-how-to-generate-data-parallel

In [4]:
class PageWordStats(object):
    def __init__(self, read_path,
                 ngram_model_file = "title_category_ngram_model.pickle",
                 output_path = "gdrive/My Drive/Projects with Wei/wiki_data/page_word_stats.json",
                 n_chunk = 10,
                 ):
        if read_path is not None:
            config = json.load(open(read_path, 'r'))
            self.word_frequency = config['word_frequency']
            self.word2id = config['word2id']
            self.id2word = config['id2word']
            self.page_frequency = {int(k):v for k, v in config['page_frequency'].items()}

            self.page2id = {int(k):v for k, v in config['page2id'].items()}
            self.id2page = config['id2page']
        else:
            # recomptue the target page stats

            gen = read_link_pairs_chunks(n_chunk = n_chunk)
            print('generating page id stats...')
            s_page = []
            for df_chunk in gen:
                val_counts = df_chunk['page_id_target'].append(df_chunk['page_id_source']).value_counts()
                s_page.append(val_counts)
            df_stats = (
                pd.concat(s_page)
                .rename_axis("page_id")  #rename index
                .to_frame('count')
                .groupby("page_id")
                    .sum()
                .reset_index()
            )
            self.page_frequency = {row.page_id: row.count 
                                   for row in df_stats.itertuples(index = False)}
            self.page2id = {p:i for i, p in enumerate(df_stats['page_id'])}
            self.id2page = df_stats['page_id'].tolist()

            # recompute the word stats
            print('generating word/ngram stats from title and categories...')

            ngram_model = load_ngram_model(NGRAM_MODEL_PATH_PREFIX + ngram_model_file)
            title_transformed, category_transformed = get_transformed_title_category(ngram_model)
            s_words = (
                pd.Series(itertools.chain(*(title_transformed + category_transformed)))
                .value_counts()
            )
            # word_frequency is a list, where ith element is the word frequency of word with id i.
            self.word_frequency = s_words.tolist()
            self.word2id = {w:i for i, w in enumerate(s_words.index)}
            self.id2word = s_words.index.tolist()
            json.dump({
                'word_frequency': self.word_frequency,
                'word2id': self.word2id,
                'id2word':self.id2word,
                'page_frequency': self.page_frequency,
                'page2id': self.page2id,
                'id2page': self.id2page,
                }, 
                open(output_path, 'w'))
        print('There are %d unique words/ngrams' % len(self.word2id))
        print('There are %d unique pages' % len(self.page2id))

In [5]:
# %%time
# page_word_stats = PageWordStats(read_path = None,n_chunk=10)

In [6]:
page_word_stats = PageWordStats(read_path = "gdrive/My Drive/Projects with Wei/wiki_data/page_word_stats.json")

There are 1748542 unique words/ngrams
There are 5330812 unique pages


In [7]:
NEGATIVE_TABLE_SIZE = 1e8
class WikiDataset(torch.utils.data.IterableDataset):
    'Characterizes a dataset for PyTorch'
    def __init__(self, file_list, compression, n_chunk, num_negs, page_word_stats, ns_exponent):
        'Initialization'
        #self.labels = labels
        # self.list_IDs = list_IDs
        # if isinstance(file_list, str):
        #     file_list = [file_list]
        self.file_list = file_list
        self.compression = compression
        self.n_chunk = n_chunk
        self.pos = 0

        self.chunk_iterator = None
        self.instance_dict = None


        self.negatives = []
        self.negpos = 0
        self.num_negs = num_negs
        self.page_word_stats = page_word_stats

        self.word_frequency = page_word_stats.word_frequency
        self.word2id = page_word_stats.word2id
        self.id2word = page_word_stats.id2word
        self.page2id = page_word_stats.page2id
        p,i = zip(*self.page2id.items())
        self.page2id_series_map = pd.Series(i, index = p, dtype = np.int64)
        self.id2page = page_word_stats.id2page
        self.page_frequency = page_word_stats.page_frequency
        self.initTableNegatives(ns_exponent=ns_exponent)

    # Iterable may not know the length of the stream before hand
    # def __len__(self):
    #     'Denotes the total number of samples'
    #     return len(self.list_IDs)

    # def __getitem__(self, index):
    #     'Generates one sample of data'
    #     # Select sample
    #     ID = self.list_IDs[index]

    #     # Load data and get label
    #     X = torch.load('data/' + ID + '.pt')
    #     y = self.labels[ID]

    #     return X, y
    
    def __iter__(self):
        return self
    
    def __next__(self):
        if self.chunk_iterator is None:
            self.chunk_iterator = read_files_in_chunks(self.file_list, compression=self.compression, 
                                                       n_chunk = self.n_chunk, progress_bar = True)
            
        if self.instance_dict is None or self.pos >= len(self.instance_dict):
            #print('read_new_chunk', flush = True)
            df = next(self.chunk_iterator)
            # have to reset self.pos after pull next in above iterator, otherwise 
            # it does not invalidate current instance_dict after iterator is exhausted.
            self.pos = 0

            df = df.assign(
                    page_id_source = lambda df: df['page_id_source'].map(self.page2id_series_map),
                    page_id_target = lambda df: df['page_id_target'].map(self.page2id_series_map),
                )

            self.instance_dict = list(df.itertuples(index = False, name = None))
        ret = self.instance_dict[self.pos]
        self.pos += 1
        return ret

    def getNegatives(self, target, size):  # TODO check equality with target
        response = self.negatives[self.negpos:self.negpos + size]
        # reshuffle negative table if negpos > total neg table size.
        if (self.negpos + size) // len(self.negatives) >= 1:
            # print('reshuffle negative table...')
            np.random.shuffle(self.negatives)
        self.negpos = (self.negpos + size) % len(self.negatives)
        if len(response) != size:
            return np.concatenate((response, self.negatives[0:self.negpos]))
        return response

    def initTableNegatives(self, ns_exponent):
        print('Initializing negative samples', flush=True)
        page_id, page_counts = zip(*self.page_frequency.items())
        ratio = np.array(page_counts).astype(np.float64) ** ns_exponent / sum(page_counts)
        sampled_count = np.round(ratio * NEGATIVE_TABLE_SIZE)

        df = pd.DataFrame.from_records(enumerate(sampled_count))
        # the column 0 is the page id, column 1 is the count of the page
        self.negatives = np.repeat(df[0].astype(np.int64).values, df[1].astype(np.int64).values)
        np.random.shuffle(self.negatives)

    def collate(self,batches):
        negs = self.getNegatives(None, self.num_negs * len(batches)).reshape((len(batches), self.num_negs))
        id_list, positive_list = zip(*batches)

        return torch.LongTensor(id_list), torch.LongTensor(positive_list), torch.from_numpy(negs)

    @staticmethod
    def worker_init_fn(worker_id, file_handle_lists):
        worker_info = torch.utils.data.get_worker_info()
        dataset = worker_info.dataset  # the dataset copy in this worker process
        worker_id = worker_info.id

        np.random.seed(np.random.get_state()[1][0] + worker_id) 
        np.random.shuffle(dataset.negatives)
        dataset.file_list = file_handle_lists[worker_id]

In [8]:
#%pdb

In [9]:
def test():
    tt = get_file_handles_in_zip(LINK_PAIRS_LOCATION)
    for i, d in enumerate(WikiDataset(tt, compression='zip', n_chunk=100, num_negs=5, page_word_stats = page_word_stats, ns_exponent = 0.5)):
        print(d)
        if i > 100:
            break


In [10]:
#%lprun -f WikiDataset.__next__ test()

In [11]:
# # Datasets
# partition = # IDs
# labels = # Labels

# #dataloader -> iterator
# training_set = Dataset(partition['train'], labels)
# validation_set = Dataset(partition['validation'], labels)

# #TODO: load trainset and testset
# training_generator  = DataLoader(trainset, **data_params)
# validation_generator  = DataLoader(testset, **data_params)

Model

In [12]:
class two_tower(nn.Module):
    def __init__(self, vocab_size, embedding_dim, context_size, hidden_dim1, out_dim):
        super(nn.Module, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear1 = nn.Linear(context_size * embedding_dim, hidden_dim1)
        self.linear2 = nn.Linear(hidden_dim1, out_dim)

    def forward_left(self, inputs):
        embeds = self.embeddings(inputs).view((1, -1))
        h = F.relu(self.linear1(embeds))
        out = F.relu(self.linear2(h))
        # log_probs = F.log_softmax(out, dim=1)
        return out

    def forward_right(self, inputs):
        embeds = self.embeddings(inputs).view((1, -1))
        h = F.relu(self.linear1(embeds))
        out = F.relu(self.linear2(h))
        # log_probs = F.log_softmax(out, dim=1)
        return out

    def forward(self, inputs):
        inputs_left, inputs_right = inputs
        return self.forward_left(inputs_left), self.forward_right(inputs_right)

In [13]:
class OneTower(nn.Module):
    def __init__(self, corpus_size, embedding_dim, hidden_dim1, out_dim, sparse, single_layer = False):
        super(OneTower, self).__init__()
        self.embeddings = nn.Embedding(corpus_size, embedding_dim, sparse=sparse)
        self.single_layer = single_layer

        # if single_layer is True, it essentially become w2v model with single hidden layer
        if not self.single_layer:
            self.linear1 = nn.Linear(embedding_dim, hidden_dim1)
            self.linear2 = nn.Linear(hidden_dim1, out_dim)
        self.out_embeddings = nn.Embedding(corpus_size, embedding_dim, sparse=sparse)
        
    def forward(self, pos_u, pos_v, neg_v):
        emb_u = self.embeddings(pos_u)
        emb_v = self.out_embeddings(pos_v)
        emb_neg_v = self.out_embeddings(neg_v)

        if not self.single_layer:
            h1 = F.relu(self.linear1(emb_u))
            h2 = F.relu(self.linear2(h1))
        else:
            h2 = emb_v

        score = torch.sum(torch.mul(h2, emb_v), dim=1)
        score = torch.clamp(score, max=10, min=-10)
        score = -F.logsigmoid(score)


        neg_score = torch.bmm(emb_neg_v, h2.unsqueeze(2)).squeeze()
        neg_score = torch.clamp(neg_score, max=10, min=-10)
        neg_score = -torch.sum(F.logsigmoid(-neg_score), dim=1)

        return torch.mean(score + neg_score)    

In [14]:
#%pdb

In [15]:
class MultipleOptimizer:
    def __init__(self, *op):
        self.optimizers = op

    def zero_grad(self):
        for op in self.optimizers:
            op.zero_grad()

    def step(self):
        for op in self.optimizers:
            op.step()

class WikiTrainer:

    def __init__(self, page_word_stats, hidden_dim1, out_dim, embedding_dim=100, batch_size=32, window_size=5, iterations=3,
                 initial_lr=0.001, min_count=12, num_workers=0, collate_fn='custom', iprint=500, t=1e-3, ns_exponent=0.75, 
                 optimizer='adam', optimizer_kwargs = None, warm_start_model = None, lr_schedule = False, timeout = 60, n_chunk = 20,
                 sparse=False, single_layer=False, test = False, save_embedding = True):

        file_handle_lists = get_files_in_dir(LINK_PAIRS_LOCATION)
        if test:
            file_handle_lists = file_handle_lists[:2]
            num_workers = 0
            n_chunk = 1
            
        dataset = WikiDataset(file_list = None, compression = 'zip', n_chunk = n_chunk, 
                              page_word_stats = page_word_stats, num_negs=5, 
                              ns_exponent=ns_exponent)
        if collate_fn == 'custom':
            collate_fn = dataset.collate
        else:
            collate_fn = None
        
        if num_workers > 0:
            file_handle_lists = np.array_split(file_handle_lists, num_workers)
        else:
            timeout = 0
            dataset.file_list = file_handle_lists

        self.dataloader = DataLoader(dataset, batch_size=batch_size,
                                     shuffle=False, num_workers=num_workers, 
                                     collate_fn=collate_fn, 
                                     worker_init_fn=partial(dataset.worker_init_fn, file_handle_lists=file_handle_lists),
                                     timeout = timeout,
                                    )

        # self.output_file_name = output_file
        self.corpus_size = len(dataset.page_frequency)
        self.embedding_dim = embedding_dim
        self.save_embedding = save_embedding
        self.iprint = iprint
        self.batch_size = batch_size
        self.iterations = iterations
        self.initial_lr = initial_lr
        self.model = OneTower(self.corpus_size, self.embedding_dim, 
 #                             context_size = dataset.num_neg, 
                              hidden_dim1 = hidden_dim1, 
                              out_dim = out_dim, sparse=sparse,
                              single_layer = single_layer,
                              )

        if warm_start_model is not None:
            self.model.load_state_dict(torch.load(warm_start_model), strict=False)
        self.optimizer = optimizer
        if optimizer_kwargs is None:
            optimizer_kwargs = {}
        self.optimizer_kwargs = optimizer_kwargs
        self.lr_schedule = lr_schedule
        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if self.use_cuda else "cpu")
        if self.use_cuda:
            self.model.cuda()

    def train(self):
        # clearn GPU memory cache
        gc.collect()
        torch.cuda.empty_cache()

        if self.optimizer == 'adam':
            optimizer = optim.Adam(self.model.parameters(), lr=self.initial_lr, **self.optimizer_kwargs)
        elif self.optimizer == 'sparse_adam':
            optimizer = optim.SparseAdam(self.model.parameters(), lr=self.initial_lr, **self.optimizer_kwargs)
        elif self.optimizer == 'sparse_dense_adam':
            opti_sparse = optim.SparseAdam([self.model.embeddings.weight, self.model.out_embeddings.weight], lr=self.initial_lr, **self.optimizer_kwargs)
            opti_dense = optim.Adam([self.model.linear1.weight, self.model.linear2.weight], lr=self.initial_lr, **self.optimizer_kwargs)
            optimizer = MultipleOptimizer(opti_sparse, opti_dense)
        elif self.optimizer == 'sgd':
            optimizer = optim.SGD(self.model.parameters(), lr=self.initial_lr, **self.optimizer_kwargs)
        elif self.optimizer == 'asgd':
            optimizer = optim.ASGD(self.model.parameters(), lr=self.initial_lr, **self.optimizer_kwargs)
        elif self.optimizer == 'adagrad':
            optimizer = optim.Adagrad(self.model.parameters(), lr=self.initial_lr, **self.optimizer_kwargs)
        else:
            raise Exception('Unknown optimizer!')

        for iteration in range(self.iterations):

            print("\n\n\nIteration: " + str(iteration + 1))

            if self.lr_schedule:
                scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(self.dataloader))
            running_loss = 0.0
            iprint = self.iprint #len(self.dataloader) // 20
            for i, sample_batched in enumerate(self.dataloader):
                # ipdb.set_trace()
                if len(sample_batched[0]) > 1:
                    pos_u = sample_batched[0].to(self.device)
                    pos_v = sample_batched[1].to(self.device)
                    neg_v = sample_batched[2].to(self.device)
                    
                    optimizer.zero_grad()
                    loss = self.model.forward(pos_u, pos_v, neg_v)
                    loss.backward()
                    optimizer.step()
                    if self.lr_schedule:
                        scheduler.step()

                    running_loss = running_loss * (1 - 5/iprint) + loss.item() * (5/iprint)
                    if i > 0 and i % iprint == 0:
                        # print(" Loss: " + str(running_loss) + ' lr: ' 
                        #     + str([param_group['lr'] for param_group in optimizer.param_groups]))
                        print(" Loss: " + str(running_loss))
            print(" Loss: " + str(running_loss))

        #self.skip_gram_model.save_embedding(self.data.id2word, self.output_file_name)
        if self.save_embedding:
            print('Saving embeddings...', flush=True)
            embeddings = self.model.embeddings.weight.cpu().data.numpy()
            np.savez_compressed('gdrive/My Drive/Projects with Wei/wiki_data/wiki_embedding/embedding.npz', embeddings)

In [16]:
def train_model(epochs = 1, collate_fn='custom',num_workers=1, test = False, n_chunk = 20, save_embedding=True):
  #ipdb.set_trace()
  wt = WikiTrainer(
                        page_word_stats = page_word_stats, 
                        hidden_dim1 = 100, 
                        out_dim = 100, 
                        # output_file="gdrive/My Drive/Projects with Wei/Wei_tmp_outputs/w2v_output/out.vec",
                        # min_count=5,
                        batch_size=8096,
                        iterations=epochs,
                        num_workers=num_workers,
                        collate_fn=collate_fn,
                        iprint=5000,
                        n_chunk = n_chunk,
                        embedding_dim=100,
                        ns_exponent=0.75,
                        initial_lr=0.002,
                        optimizer='sparse_adam',
                        single_layer=True,
                        sparse=True,
                        lr_schedule=False,
                        test=test,
                        save_embedding=save_embedding,
 #                       optimizer_kwargs={'momentum':0.9},
 #                       warm_start_model = 'gdrive/My Drive/Projects with Wei/Wei_tmp_outputs/torch_model/model.pkl',
                        )
  wt.train()
  return wt

In [17]:
gc.collect()
torch.cuda.empty_cache()

In [18]:
train_model(epochs=20,
            num_workers=3, 
            test = False, 
            save_embedding=True)

Initializing negative samples



Iteration: 1


  5%|▌         | 1/20 [01:10<22:10, 70.03s/it]

 Loss: 14.705712729023572


 15%|█▌        | 3/20 [03:29<19:46, 69.78s/it]

 Loss: 14.020082445807551


 25%|██▌       | 5/20 [05:49<17:27, 69.80s/it]

 Loss: 13.48529598702027


 35%|███▌      | 7/20 [08:14<15:28, 71.42s/it]

 Loss: 13.031981582442004


 50%|█████     | 10/20 [10:52<10:17, 61.72s/it]

 Loss: 12.620547089598384


 60%|██████    | 12/20 [13:09<08:04, 60.58s/it]

 Loss: 12.238354696553483


 70%|███████   | 14/20 [15:02<05:50, 58.39s/it]

 Loss: 11.878915310497542


 85%|████████▌ | 17/20 [17:42<02:58, 59.56s/it]

 Loss: 11.534206094077579


 95%|█████████▌| 19/20 [20:05<01:00, 60.34s/it]

 Loss: 11.20488226139866


100%|██████████| 20/20 [20:42<00:00, 62.13s/it]
100%|██████████| 20/20 [20:58<00:00, 62.91s/it]
100%|██████████| 20/20 [20:58<00:00, 62.93s/it]


 Loss: 11.130658483844984



Iteration: 2


  5%|▌         | 1/20 [01:08<21:46, 68.74s/it]

 Loss: 9.12624919620048


 15%|█▌        | 3/20 [03:45<20:50, 73.55s/it]

 Loss: 9.338702698246882


 25%|██▌       | 5/20 [06:22<19:00, 76.03s/it]

 Loss: 9.28243801579988


 35%|███▌      | 7/20 [09:11<17:31, 80.91s/it]

 Loss: 9.159208511143031


 50%|█████     | 10/20 [12:08<11:32, 69.26s/it]

 Loss: 9.003360168169927


 60%|██████    | 12/20 [14:32<08:46, 65.78s/it]

 Loss: 8.834916685606775


 70%|███████   | 14/20 [16:35<06:21, 63.54s/it]

 Loss: 8.661599498036727


 85%|████████▌ | 17/20 [19:22<03:08, 62.86s/it]

 Loss: 8.488044523666446


 95%|█████████▌| 19/20 [21:57<01:04, 64.76s/it]

 Loss: 8.315058143360355


100%|██████████| 20/20 [22:38<00:00, 67.95s/it]
100%|██████████| 20/20 [22:54<00:00, 68.73s/it]
100%|██████████| 20/20 [22:55<00:00, 68.75s/it]


 Loss: 8.272815172636834



Iteration: 3


  5%|▌         | 1/20 [01:08<21:46, 68.78s/it]

 Loss: 7.042097638553477


 15%|█▌        | 3/20 [03:46<20:54, 73.80s/it]

 Loss: 7.186788740361588


 25%|██▌       | 5/20 [06:22<19:00, 76.02s/it]

 Loss: 7.16451801810809


 35%|███▌      | 7/20 [09:13<17:35, 81.18s/it]

 Loss: 7.104052870718289


 50%|█████     | 10/20 [12:10<11:32, 69.26s/it]

 Loss: 7.024983399651482


 60%|██████    | 12/20 [14:33<08:47, 65.94s/it]

 Loss: 6.937885095441418


 70%|███████   | 14/20 [16:36<06:21, 63.51s/it]

 Loss: 6.848280277620312


 85%|████████▌ | 17/20 [19:24<03:08, 62.95s/it]

 Loss: 6.757251297398909


 95%|█████████▌| 19/20 [21:58<01:04, 64.77s/it]

 Loss: 6.665298098393561


100%|██████████| 20/20 [22:40<00:00, 68.04s/it]
100%|██████████| 20/20 [22:55<00:00, 68.80s/it]
100%|██████████| 20/20 [22:56<00:00, 68.81s/it]


 Loss: 6.640984955353588



Iteration: 4


  5%|▌         | 1/20 [01:08<21:41, 68.48s/it]

 Loss: 5.842159752767573


 15%|█▌        | 3/20 [03:45<20:48, 73.46s/it]

 Loss: 5.952811795620117


 25%|██▌       | 5/20 [06:22<19:01, 76.08s/it]

 Loss: 5.949601193402164


 35%|███▌      | 7/20 [09:12<17:33, 81.02s/it]

 Loss: 5.921586902032292


 50%|█████     | 10/20 [12:09<11:33, 69.38s/it]

 Loss: 5.881117123147254


 60%|██████    | 12/20 [14:33<08:47, 65.92s/it]

 Loss: 5.834848915180094


 70%|███████   | 14/20 [16:37<06:21, 63.66s/it]

 Loss: 5.7863427410066235


 85%|████████▌ | 17/20 [19:24<03:08, 62.91s/it]

 Loss: 5.735112263829685


 95%|█████████▌| 19/20 [21:59<01:04, 64.94s/it]

 Loss: 5.682560427720571


100%|██████████| 20/20 [22:40<00:00, 68.04s/it]
100%|██████████| 20/20 [22:56<00:00, 68.81s/it]
100%|██████████| 20/20 [22:56<00:00, 68.84s/it]


 Loss: 5.6671948555324345



Iteration: 5


  5%|▌         | 1/20 [01:08<21:36, 68.22s/it]

 Loss: 5.1187579103088465


 15%|█▌        | 3/20 [03:45<20:47, 73.41s/it]

 Loss: 5.205477005739818


 25%|██▌       | 5/20 [06:22<18:58, 75.89s/it]

 Loss: 5.207859828353313


 35%|███▌      | 7/20 [09:01<16:50, 77.70s/it]

 Loss: 5.193493870212259


 50%|█████     | 10/20 [11:34<10:26, 62.61s/it]

 Loss: 5.170238957232028


 60%|██████    | 12/20 [13:43<07:56, 59.60s/it]

 Loss: 5.143098596937443


 70%|███████   | 14/20 [15:36<05:47, 57.85s/it]

 Loss: 5.114118693820992


 85%|████████▌ | 17/20 [17:58<02:45, 55.07s/it]

 Loss: 5.0823021309002625


 95%|█████████▌| 19/20 [20:17<00:57, 57.60s/it]

 Loss: 5.049486959766672


100%|██████████| 20/20 [20:53<00:00, 62.69s/it]
100%|██████████| 20/20 [21:09<00:00, 63.48s/it]
100%|██████████| 20/20 [21:10<00:00, 63.50s/it]


 Loss: 5.038798552088651



Iteration: 6


  5%|▌         | 1/20 [01:08<21:39, 68.39s/it]

 Loss: 4.644592556280059


 15%|█▌        | 3/20 [03:45<20:50, 73.56s/it]

 Loss: 4.714480143755549


 25%|██▌       | 5/20 [06:22<18:59, 75.99s/it]

 Loss: 4.717834830680678


 35%|███▌      | 7/20 [09:14<17:38, 81.43s/it]

 Loss: 4.709556247940599


 50%|█████     | 10/20 [12:12<11:36, 69.69s/it]

 Loss: 4.694848571653949


 60%|██████    | 12/20 [14:36<08:48, 66.11s/it]

 Loss: 4.677890212219312


 70%|███████   | 14/20 [16:39<06:21, 63.62s/it]

 Loss: 4.658941678280112


 85%|████████▌ | 17/20 [19:29<03:10, 63.52s/it]

 Loss: 4.6376285979320135


 95%|█████████▌| 19/20 [22:03<01:04, 64.90s/it]

 Loss: 4.615689894999125


100%|██████████| 20/20 [22:44<00:00, 68.24s/it]
100%|██████████| 20/20 [23:00<00:00, 69.02s/it]
100%|██████████| 20/20 [23:00<00:00, 69.04s/it]


 Loss: 4.607656718658163



Iteration: 7


  5%|▌         | 1/20 [01:08<21:45, 68.72s/it]

 Loss: 4.313271718963405


 15%|█▌        | 3/20 [03:45<20:51, 73.62s/it]

 Loss: 4.371332271032591


 25%|██▌       | 5/20 [06:22<19:01, 76.12s/it]

 Loss: 4.374425640528766


 35%|███▌      | 7/20 [09:12<17:33, 81.03s/it]

 Loss: 4.369206503169216


 50%|█████     | 10/20 [12:10<11:34, 69.47s/it]

 Loss: 4.359030220640019


 60%|██████    | 12/20 [14:32<08:46, 65.80s/it]

 Loss: 4.348024665145185


 70%|███████   | 14/20 [16:36<06:22, 63.67s/it]

 Loss: 4.334569747803158


 85%|████████▌ | 17/20 [19:24<03:08, 62.91s/it]

 Loss: 4.319498697172901


 95%|█████████▌| 19/20 [21:57<01:04, 64.77s/it]

 Loss: 4.30395720034262


100%|██████████| 20/20 [22:40<00:00, 68.02s/it]
100%|██████████| 20/20 [22:55<00:00, 68.79s/it]
100%|██████████| 20/20 [22:55<00:00, 68.78s/it]


 Loss: 4.297587494882682



Iteration: 8


  5%|▌         | 1/20 [01:07<21:26, 67.69s/it]

 Loss: 4.0699991454419235


 15%|█▌        | 3/20 [03:44<20:41, 73.04s/it]

 Loss: 4.119761711991808


 25%|██▌       | 5/20 [06:21<18:55, 75.67s/it]

 Loss: 4.122288783690033


 35%|███▌      | 7/20 [09:11<17:33, 81.00s/it]

 Loss: 4.11873333952001


 50%|█████     | 10/20 [12:10<11:36, 69.69s/it]

 Loss: 4.111189101242069


 60%|██████    | 12/20 [14:35<08:51, 66.48s/it]

 Loss: 4.103961294703394


 70%|███████   | 14/20 [16:39<06:23, 63.97s/it]

 Loss: 4.09375037126357


 85%|████████▌ | 17/20 [19:23<03:07, 62.55s/it]

 Loss: 4.082711400733162


 95%|█████████▌| 19/20 [22:02<01:05, 65.84s/it]

 Loss: 4.071230416854385


100%|██████████| 20/20 [22:45<00:00, 68.29s/it]
100%|██████████| 20/20 [23:01<00:00, 69.10s/it]
100%|██████████| 20/20 [23:01<00:00, 69.07s/it]


 Loss: 4.065997625632889



Iteration: 9


  5%|▌         | 1/20 [01:08<21:40, 68.45s/it]

 Loss: 3.884454569944287


 15%|█▌        | 3/20 [03:25<19:23, 68.46s/it]

 Loss: 3.9283630135122043


 25%|██▌       | 5/20 [05:42<17:05, 68.40s/it]

 Loss: 3.930414395033108


 35%|███▌      | 7/20 [08:05<15:11, 70.08s/it]

 Loss: 3.927877032886145


 50%|█████     | 10/20 [10:40<10:04, 60.50s/it]

 Loss: 3.9220902803930318


 60%|██████    | 12/20 [12:57<08:04, 60.51s/it]

 Loss: 3.9173689490751364


 70%|███████   | 14/20 [14:55<05:57, 59.51s/it]

 Loss: 3.9093340886611765


 85%|████████▌ | 17/20 [17:24<02:50, 56.86s/it]

 Loss: 3.9010542952951064


 95%|█████████▌| 19/20 [19:39<00:57, 57.63s/it]

 Loss: 3.8923708715443732


100%|██████████| 20/20 [20:15<00:00, 60.79s/it]
100%|██████████| 20/20 [20:30<00:00, 61.53s/it]
100%|██████████| 20/20 [20:30<00:00, 61.53s/it]


 Loss: 3.887971512609476



Iteration: 10


  5%|▌         | 1/20 [01:08<21:50, 68.98s/it]

 Loss: 3.7390392543998914


 15%|█▌        | 3/20 [03:25<19:26, 68.60s/it]

 Loss: 3.7787746968150953


 25%|██▌       | 5/20 [05:42<17:04, 68.28s/it]

 Loss: 3.780522709359464


 35%|███▌      | 7/20 [08:05<15:11, 70.13s/it]

 Loss: 3.7786796294615055


 50%|█████     | 10/20 [10:39<10:02, 60.29s/it]

 Loss: 3.774190177532299


 60%|██████    | 12/20 [12:55<07:59, 59.94s/it]

 Loss: 3.7711675090788055


 70%|███████   | 14/20 [14:52<05:54, 59.14s/it]

 Loss: 3.7647330176076657


 85%|████████▌ | 17/20 [17:20<02:50, 56.72s/it]

 Loss: 3.7584649051639207


 95%|█████████▌| 19/20 [19:36<00:57, 57.61s/it]

 Loss: 3.751780385819096


100%|██████████| 20/20 [20:11<00:00, 60.57s/it]
100%|██████████| 20/20 [20:26<00:00, 61.33s/it]
100%|██████████| 20/20 [20:27<00:00, 61.35s/it]


 Loss: 3.748010270936833



Iteration: 11


  5%|▌         | 1/20 [01:08<21:33, 68.09s/it]

 Loss: 3.6227681649346315


 15%|█▌        | 3/20 [03:24<19:16, 68.01s/it]

 Loss: 3.6594491591137324


 25%|██▌       | 5/20 [05:40<17:00, 68.03s/it]

 Loss: 3.6610024450516683


 35%|███▌      | 7/20 [08:03<15:07, 69.84s/it]

 Loss: 3.6596596798270777


 50%|█████     | 10/20 [10:37<10:01, 60.13s/it]

 Loss: 3.6561857554490205


 60%|██████    | 12/20 [12:53<07:59, 59.94s/it]

 Loss: 3.6543478247365906


 70%|███████   | 14/20 [14:49<05:53, 58.98s/it]

 Loss: 3.6490875128827818


 85%|████████▌ | 17/20 [17:16<02:49, 56.49s/it]

 Loss: 3.6443201754672625


 95%|█████████▌| 19/20 [19:32<00:57, 57.28s/it]

 Loss: 3.6390925339637326


100%|██████████| 20/20 [20:07<00:00, 60.36s/it]
100%|██████████| 20/20 [20:22<00:00, 61.13s/it]
100%|██████████| 20/20 [20:22<00:00, 61.14s/it]


 Loss: 3.635791598975788



Iteration: 12


  5%|▌         | 1/20 [01:07<21:25, 67.68s/it]

 Loss: 3.528123135737171


 15%|█▌        | 3/20 [03:24<19:14, 67.91s/it]

 Loss: 3.5624631792932644


 25%|██▌       | 5/20 [05:39<16:57, 67.84s/it]

 Loss: 3.563900690501498


 35%|███▌      | 7/20 [08:02<15:06, 69.71s/it]

 Loss: 3.562948383140061


 50%|█████     | 10/20 [10:36<10:00, 60.09s/it]

 Loss: 3.560232144712334


 60%|██████    | 12/20 [12:51<07:58, 59.75s/it]

 Loss: 3.559224700837535


 70%|███████   | 14/20 [14:48<05:53, 58.95s/it]

 Loss: 3.5548721345209193


 85%|████████▌ | 17/20 [17:15<02:49, 56.52s/it]

 Loss: 3.551192268423351


 95%|█████████▌| 19/20 [19:32<00:57, 57.54s/it]

 Loss: 3.547065936704818


100%|██████████| 20/20 [20:07<00:00, 60.36s/it]
100%|██████████| 20/20 [20:22<00:00, 61.14s/it]
100%|██████████| 20/20 [20:22<00:00, 61.13s/it]


 Loss: 3.544112148355266



Iteration: 13


  5%|▌         | 1/20 [01:08<21:34, 68.14s/it]

 Loss: 3.4497057025944513


 15%|█▌        | 3/20 [03:24<19:17, 68.09s/it]

 Loss: 3.4822049524679275


 25%|██▌       | 5/20 [05:40<16:59, 67.98s/it]

 Loss: 3.4835767997335028


 35%|███▌      | 7/20 [08:02<15:04, 69.57s/it]

 Loss: 3.482941063451076


 50%|█████     | 10/20 [10:36<10:02, 60.21s/it]

 Loss: 3.480798077302176


 60%|██████    | 12/20 [12:52<07:59, 59.89s/it]

 Loss: 3.4803556010458894


 70%|███████   | 14/20 [14:49<05:54, 59.09s/it]

 Loss: 3.476703075997805


 85%|████████▌ | 17/20 [17:17<02:49, 56.60s/it]

 Loss: 3.473828143512969


 95%|█████████▌| 19/20 [19:34<00:57, 57.70s/it]

 Loss: 3.4705269835812556


100%|██████████| 20/20 [20:09<00:00, 60.47s/it]
100%|██████████| 20/20 [20:24<00:00, 61.23s/it]
100%|██████████| 20/20 [20:24<00:00, 61.25s/it]


 Loss: 3.467818291978665



Iteration: 14


  5%|▌         | 1/20 [01:08<21:40, 68.47s/it]

 Loss: 3.3836038713179324


 15%|█▌        | 3/20 [03:24<19:20, 68.29s/it]

 Loss: 3.414648836482624


 25%|██▌       | 5/20 [05:40<17:00, 68.05s/it]

 Loss: 3.4159498983390066


 35%|███▌      | 7/20 [08:02<15:04, 69.60s/it]

 Loss: 3.4155601151403703


 50%|█████     | 10/20 [10:38<10:05, 60.59s/it]

 Loss: 3.4138412517817893


 60%|██████    | 12/20 [12:54<08:00, 60.12s/it]

 Loss: 3.4137877277297193


 70%|███████   | 14/20 [14:51<05:56, 59.34s/it]

 Loss: 3.410678148730665


 85%|████████▌ | 17/20 [17:19<02:50, 56.81s/it]

 Loss: 3.408392731981475


 95%|█████████▌| 19/20 [19:36<00:57, 57.57s/it]

 Loss: 3.4057339981761823


100%|██████████| 20/20 [20:11<00:00, 60.56s/it]
100%|██████████| 20/20 [20:26<00:00, 61.31s/it]
100%|██████████| 20/20 [20:26<00:00, 61.32s/it]


 Loss: 3.403204565152367



Iteration: 15


  5%|▌         | 1/20 [01:07<21:28, 67.84s/it]

 Loss: 3.326941695221579


 15%|█▌        | 3/20 [03:24<19:15, 67.97s/it]

 Loss: 3.3568218776542045


 25%|██▌       | 5/20 [05:40<17:00, 68.05s/it]

 Loss: 3.358046854997448


 35%|███▌      | 7/20 [08:02<15:05, 69.69s/it]

 Loss: 3.3578278265243253


 50%|█████     | 10/20 [10:37<10:02, 60.20s/it]

 Loss: 3.356443563947259


 60%|██████    | 12/20 [12:52<07:58, 59.80s/it]

 Loss: 3.356644670138137


 70%|███████   | 14/20 [14:50<05:54, 59.16s/it]

 Loss: 3.353955246815088


 85%|████████▌ | 17/20 [17:18<02:50, 56.72s/it]

 Loss: 3.3521165520120895


 95%|█████████▌| 19/20 [19:35<00:57, 57.76s/it]

 Loss: 3.3499424004266736


100%|██████████| 20/20 [20:10<00:00, 60.51s/it]
100%|██████████| 20/20 [20:26<00:00, 61.31s/it]
100%|██████████| 20/20 [20:25<00:00, 61.28s/it]


 Loss: 3.3475573725551273



Iteration: 16


  5%|▌         | 1/20 [01:07<21:31, 67.96s/it]

 Loss: 3.277597748256266


 15%|█▌        | 3/20 [03:24<19:16, 68.02s/it]

 Loss: 3.306498635169548


 25%|██▌       | 5/20 [05:40<17:00, 68.04s/it]

 Loss: 3.307676682875085


 35%|███▌      | 7/20 [08:03<15:08, 69.91s/it]

 Loss: 3.307572123391465


 50%|█████     | 10/20 [10:37<10:02, 60.25s/it]

 Loss: 3.3064429217389826


 60%|██████    | 12/20 [12:53<07:59, 59.93s/it]

 Loss: 3.3068266826434045


 70%|███████   | 14/20 [14:50<05:54, 59.17s/it]

 Loss: 3.3044677815089942


 85%|████████▌ | 17/20 [17:18<02:50, 56.75s/it]

 Loss: 3.302966075636966


 95%|█████████▌| 19/20 [19:35<00:57, 57.61s/it]

 Loss: 3.301175478688527


100%|██████████| 20/20 [20:10<00:00, 60.53s/it]
100%|██████████| 20/20 [20:25<00:00, 61.30s/it]
100%|██████████| 20/20 [20:26<00:00, 61.31s/it]


 Loss: 3.298899782473207



Iteration: 17


  5%|▌         | 1/20 [01:08<21:39, 68.42s/it]

 Loss: 3.233995855857122


 15%|█▌        | 3/20 [03:25<19:21, 68.33s/it]

 Loss: 3.2620874664168347


 25%|██▌       | 5/20 [05:41<17:02, 68.17s/it]

 Loss: 3.2632055347693107


 35%|███▌      | 7/20 [08:02<15:05, 69.62s/it]

 Loss: 3.26318471598161


 50%|█████     | 10/20 [10:39<10:06, 60.67s/it]

 Loss: 3.2622509976092076


 60%|██████    | 12/20 [12:55<08:02, 60.29s/it]

 Loss: 3.262762270344283


 70%|███████   | 14/20 [14:52<05:56, 59.34s/it]

 Loss: 3.260677247422461


 85%|████████▌ | 17/20 [17:21<02:50, 56.82s/it]

 Loss: 3.259420937755282


 95%|█████████▌| 19/20 [19:37<00:57, 57.67s/it]

 Loss: 3.2579365865705614


100%|██████████| 20/20 [20:13<00:00, 60.65s/it]
100%|██████████| 20/20 [20:28<00:00, 61.40s/it]
100%|██████████| 20/20 [20:27<00:00, 61.40s/it]


 Loss: 3.2557532145275676



Iteration: 18


  5%|▌         | 1/20 [01:08<21:34, 68.14s/it]

 Loss: 3.1949508511672247


 15%|█▌        | 3/20 [03:25<19:20, 68.25s/it]

 Loss: 3.222365911716063


 25%|██▌       | 5/20 [05:41<17:02, 68.14s/it]

 Loss: 3.223432171768583


 35%|███▌      | 7/20 [08:02<15:04, 69.61s/it]

 Loss: 3.2234680765376416


 50%|█████     | 10/20 [10:37<10:02, 60.22s/it]

 Loss: 3.2226817197554887


 60%|██████    | 12/20 [12:53<07:58, 59.84s/it]

 Loss: 3.2232858922931324


 70%|███████   | 14/20 [14:50<05:55, 59.17s/it]

 Loss: 3.221424519199263


 85%|████████▌ | 17/20 [17:18<02:50, 56.80s/it]

 Loss: 3.22035168659241


 95%|█████████▌| 19/20 [19:35<00:57, 57.79s/it]

 Loss: 3.2191214573126863


100%|██████████| 20/20 [20:11<00:00, 60.56s/it]
100%|██████████| 20/20 [20:26<00:00, 61.31s/it]
100%|██████████| 20/20 [20:26<00:00, 61.32s/it]


 Loss: 3.217002663022428



Iteration: 19


  5%|▌         | 1/20 [01:08<21:37, 68.28s/it]

 Loss: 3.1595972913093155


 15%|█▌        | 3/20 [03:25<19:22, 68.37s/it]

 Loss: 3.1864244031905327


 25%|██▌       | 5/20 [05:42<17:05, 68.40s/it]

 Loss: 3.1874339044785276


 35%|███▌      | 7/20 [08:06<15:12, 70.21s/it]

 Loss: 3.1875104699771737


 50%|█████     | 10/20 [10:40<10:04, 60.49s/it]

 Loss: 3.1868382538513393


 60%|██████    | 12/20 [12:57<08:01, 60.14s/it]

 Loss: 3.1875091254988877


 70%|███████   | 14/20 [14:55<05:56, 59.41s/it]

 Loss: 3.1858302756895682


 85%|████████▌ | 17/20 [17:23<02:50, 56.88s/it]

 Loss: 3.184887123200201


 95%|█████████▌| 19/20 [19:40<00:57, 57.71s/it]

 Loss: 3.183873407062621


100%|██████████| 20/20 [20:15<00:00, 60.75s/it]
100%|██████████| 20/20 [20:30<00:00, 61.53s/it]
100%|██████████| 20/20 [20:30<00:00, 61.53s/it]


 Loss: 3.1818007088378564



Iteration: 20


  5%|▌         | 1/20 [01:08<21:32, 68.00s/it]

 Loss: 3.127247406029966


 15%|█▌        | 3/20 [03:25<19:20, 68.28s/it]

 Loss: 3.153551753189505


 25%|██▌       | 5/20 [05:41<17:02, 68.15s/it]

 Loss: 3.15449887827236


 35%|███▌      | 7/20 [08:05<15:11, 70.09s/it]

 Loss: 3.1546029734855567


 50%|█████     | 10/20 [10:39<10:04, 60.47s/it]

 Loss: 3.15402674573876


 60%|██████    | 12/20 [12:55<07:59, 59.92s/it]

 Loss: 3.1547305563399397


 70%|███████   | 14/20 [14:52<05:54, 59.10s/it]

 Loss: 3.15320273092764


 85%|████████▌ | 17/20 [17:19<02:49, 56.60s/it]

 Loss: 3.152353044738457


 95%|█████████▌| 19/20 [19:36<00:57, 57.42s/it]

 Loss: 3.151514125973728


100%|██████████| 20/20 [20:10<00:00, 60.54s/it]
100%|██████████| 20/20 [20:26<00:00, 61.31s/it]
100%|██████████| 20/20 [20:26<00:00, 61.34s/it]


 Loss: 3.1494753585896036
Saving embeddings...


<__main__.WikiTrainer at 0x7f9c80bb6be0>

Train

In [None]:
%lprun -f WikiDataset.__next__ -f WikiDataset.collate -f WikiTrainer.train train_model(num_workers=0, test = True, n_chunk=1)

Initializing negative samples



  0%|          | 0/1 [00:00<?, ?it/s][A




Iteration: 1



100%|██████████| 1/1 [00:58<00:00, 58.34s/it]


 Loss: 5.873366421972997


In [None]:
length = 0
for c in read_link_pairs_chunks(n_chunk=10):
    print(len(c))
    length += len(c)
print(length)

  0%|          | 0/10 [00:00<?, ?it/s]

reading link pairs in 10 chunks


 10%|█         | 1/10 [00:11<01:46, 11.87s/it]

37397160


 20%|██        | 2/10 [00:22<01:32, 11.62s/it]

37397160


 30%|███       | 3/10 [00:34<01:21, 11.61s/it]

37397160


 40%|████      | 4/10 [00:45<01:08, 11.36s/it]

37397160


 50%|█████     | 5/10 [00:56<00:56, 11.20s/it]

37397160


 60%|██████    | 6/10 [01:06<00:44, 11.02s/it]

37397160


 70%|███████   | 7/10 [01:17<00:32, 10.86s/it]

37397160


 80%|████████  | 8/10 [01:27<00:21, 10.69s/it]

37397160


 90%|█████████ | 9/10 [01:37<00:10, 10.56s/it]

37397160


100%|██████████| 10/10 [01:48<00:00, 10.80s/it]

37397113
373971553





In [None]:
class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss function.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, outputs, label):
        outputs1, output2 = outputs
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
                                      (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
        return loss_contrastive

We'll define a little function to create our model and optimizer so we
can reuse it in the future.

In [None]:
def get_model(C, device):
    model = two_tower().to(device)
    return model, optim.Adam(model.parameters(), lr = C.LEARNING_RATE)

Since we go through a similar process twice of calculating the loss for both the training set and the validation set, let's make that into its own function, loss_batch, which computes the loss for one batch.

We pass an optimizer in for the training set, and use it to perform backprop. For the validation set, we don't pass an optimizer, so the method doesn't perform backprop.

In [None]:
def loss_batch(model, loss_func, xb, yb, opt=None):
    loss = loss_func(model(xb), yb)

    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()

    return loss.item(), len(xb)

fit runs the necessary operations to train our model and compute the training and validation losses for each epoch. 

(Note that we always call model.train() before training, and model.eval() before inference, because these are used by layers such as nn.BatchNorm2d and nn.Dropout to ensure appropriate behaviour for these different phases.)



In [None]:
def fit(epochs, model, loss_func, opt, train_dl, valid_dl, device):
    counter = []
    loss_history = [] 
    iteration_number= 0
    
    for epoch in range(epochs):
        model.train()
        for i, data in enumerate(train_dl, 0):
            xb, yb = data
            loss_batch(model, loss_func, xb, yb, opt)
            
            if i %10 == 0 :
            print("Epoch number {}\n Current loss {}\n".format(epoch, loss_contrastive.data[0]))
            iteration_number +=10
            counter.append(iteration_number)
            loss_history.append(loss_contrastive.data[0])

        model.eval()
        with torch.no_grad():
            losses, nums = zip(
                *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
            )
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
        print(epoch, val_loss)

    show_plot(counter,loss_history)

"get_data" returns dataloaders for the training and validation sets.

In [None]:
#TODO: send the batches to device

def get_data(X, Y, C):
    # Datasets
    X = # IDs
    Y = # Labels

    #dataloader -> iterator
    training_set = Dataset(X['train'], Y)
    validation_set = Dataset(X['validation'], Y)

    return (
        DataLoader(train_ds, **C.data_params),
        DataLoader(valid_ds, **C.data_params),
    )

main(). 

Now, our whole process of obtaining the data loaders and fitting the model can be run in 3 lines of code:

In [None]:
train_dl, valid_dl = get_data(X, Y, Config)
model, opt = get_model(Config, device)
fit(Config.NUM_EPOCHS, model, loss_func, opt, train_dl, valid_dl)

In [None]:
#@title
model, opt = get_model()
loss_func  = ContrastiveLoss()
print(loss(model(xb), yb))


counter = []
loss_history = [] 
iteration_number= 0

for epoch in range(0, Config.NUM_EPOCHS):
    model.train()
    for i, xb1, xb2 in enumerate(train_dataloader, 0):
        pred1, pred2 = model(xb1, xb2)
        loss_contrastive = loss_func(pred1, pred2, yb)
        loss_contrastive.backward()
        opt.step()
        opt.zero_grad()
  
        if i %10 == 0 :
            print("Epoch number {}\n Current loss {}\n".format(epoch, loss_contrastive.data[0]))
            iteration_number +=10
            counter.append(iteration_number)
            loss_history.append(loss_contrastive.data[0])
    
    model.eval()
    with torch.no_grad():
        valid_loss = sum(loss_func(model(xb1, xb2), yb) for xb1, xb2, yb in validation_generator)
    print(epoch, valid_loss / len(validation_generator))     
       
show_plot(counter,loss_history)