In [1]:
import io
import numpy as np
import torch
from torch import nn
import argparse
import time

from Dictionary import Dictionary
from Discriminator import Discriminator
from TrainModel import TrainModel
from Evaluator import Evaluator

In [2]:
#default & settings
useGPU = True
n_epoch_adv = 5

In [3]:
def load_embeddings(source_embedding_path, target_embedding_path, maxCount = 1e10):
    #Load embeddings
    #read txt embeddings for English(2519370 words with 300 dim) and Chinese(332647 words with 300 dim)
    word2id = {}     #e.g. u'\u5e74 = year
    vectors = []
    count = 0
    with io.open(target_embedding_path, 'r', encoding='utf-8', newline='\n', errors='ignore') as f:
        for i, line in enumerate(f):
            #print i,line
            #print i

            if count>=maxCount:
                break
            count += 1
            if i == 0:
                split = line.split()
            else:
                word, vect = line.rstrip().split(' ', 1) #stripe space from end
                #print word #real chars

                vect = np.fromstring(vect, sep=' ')

                if np.linalg.norm(vect) == 0:  # avoid to have null embeddings
                    vect[0] = 0.001 #add a little amount...
                
                word2id[word] = count-2
                vectors.append(vect[None])
    
#     print len(vectors[0]),word2id
    print "Finished loading", count, "words..."
    id2word = {v: k for k, v in word2id.items()}  #reverse of word2id
    dic = Dictionary(id2word, word2id, "zh")
    #print "len is",dic.__len__()
    embeddings = np.concatenate(vectors, 0)
    embeddings = torch.from_numpy(embeddings).float()
    return dic, embeddings

In [4]:
# load source embedding
source_embedding_path = "data/wiki.en.vec"
target_embedding_path = "data/wiki.zh.vec"
src_dic, _src_emb = load_embeddings(source_embedding_path,source_embedding_path, 100)
src_emb = nn.Embedding(len(src_dic), 300, sparse=True) #dim is set to 300..

# load target embedding
tgt_dic, _tgt_emb = load_embeddings(target_embedding_path,target_embedding_path, 100)
tgt_emb = nn.Embedding(len(tgt_dic), 300, sparse=True)


Finished loading 100 words...
Finished loading 100 words...


In [5]:
# Mapping
mapping = nn.Linear(300, 300, bias=False)

In [6]:
discriminator = Discriminator()

In [7]:
discriminator.layers

Sequential(
  (0): Dropout(p=0.1)
  (1): Linear(in_features=300, out_features=2048, bias=True)
  (2): LeakyReLU(0.2)
  (3): Dropout(p=0)
  (4): Linear(in_features=2048, out_features=2048, bias=True)
  (5): LeakyReLU(0.2)
  (6): Dropout(p=0)
  (7): Linear(in_features=2048, out_features=1, bias=True)
  (8): Sigmoid()
)

In [8]:
# use gpu
if useGPU:
    src_emb.cuda()
    tgt_emb.cuda()
    mapping.cuda()
    discriminator.cuda()

In [9]:
# do not normalize embeddings
# params.src_mean = normalize_embeddings(src_emb.weight.data, "")
# params.tgt_mean = normalize_embeddings(tgt_emb.weight.data, "")

In [10]:
#Now we have these four core part cuda: src_emb.cuda(), tgt_emb.cuda(), mapping.cuda(), discriminator.cuda()

### train model initialization

In [11]:
trainer = TrainModel(src_emb, tgt_emb, mapping, discriminator, src_dic, tgt_dic, 'sgd', 0.1)
#trainer = TrainModel(1)

In [12]:
mapping.parameters

<bound method Linear.parameters of Linear(in_features=300, out_features=300, bias=False)>

In [13]:
trainer.map_optimizer.param_groups

[{'dampening': 0,
  'lr': 0.1,
  'momentum': 0,
  'nesterov': False,
  'params': [Parameter containing:
   -3.6995e-02 -9.0221e-03  3.5143e-02  ...  -4.6831e-02  5.0774e-02 -5.4966e-02
    4.2324e-02 -4.2363e-02  1.4072e-02  ...   5.6803e-02 -1.4891e-02  7.0549e-03
    4.1286e-02  4.2584e-02 -1.0740e-02  ...  -5.0184e-02  3.4514e-02  1.4231e-02
                   ...                   ⋱                   ...                
    1.8819e-02 -8.3572e-03 -3.2718e-02  ...   1.6956e-02  1.0110e-02  3.5866e-02
   -5.6765e-02  7.4246e-03  4.5668e-02  ...   2.9385e-03  9.0126e-03 -1.9487e-02
   -5.5643e-03  1.5083e-02  1.5856e-02  ...  -1.2586e-02  1.7161e-02  3.4722e-02
   [torch.cuda.FloatTensor of size 300x300 (GPU 0)]],
  'weight_decay': 0}]

### Evaluator initialization

In [14]:
evaluator = Evaluator(trainer)

### Unsupervised Training

In [17]:
#Adversarial Training
print('--------- ADVERSARIAL TRAINING -------\n')
#epoch_size = 1000000
epoch_size = 1000
batch_size = 32
dis_steps = 5
for epoch in xrange(n_epoch_adv):
    print('Starting %i th epoch in adversarial training...' % epoch)
    tic = time.time()
    n_words_proc = 0
    stats = {'DIS_COSTS': []}
    for n_iter in range(0, epoch_size, batch_size):
        # discriminator training
        for _ in range(dis_steps):
            trainer.dis_step(stats)
    #print stats

--------- ADVERSARIAL TRAINING -------

Starting 0 th epoch in adversarial training...
{'DIS_COSTS': [0.32523685693740845, 0.3252228796482086, 0.32517892122268677, 0.3252641260623932, 0.325248658657074, 0.32522547245025635, 0.32519593834877014, 0.3251835107803345, 0.3251747488975525, 0.3252125382423401, 0.32525038719177246, 0.32523757219314575, 0.32518646121025085, 0.32520821690559387, 0.3252050280570984, 0.3252267837524414, 0.32522672414779663, 0.3252527713775635, 0.325233519077301, 0.3252032399177551, 0.3252009153366089, 0.3251890540122986, 0.32521581649780273, 0.32522889971733093, 0.3252912759780884, 0.32517626881599426, 0.3252330422401428, 0.3252061903476715, 0.3252313733100891, 0.32520920038223267, 0.32524776458740234, 0.32521766424179077, 0.325209379196167, 0.3252562880516052, 0.32526224851608276, 0.3252134919166565, 0.3253023624420166, 0.32526567578315735, 0.32522356510162354, 0.3252587914466858, 0.32521140575408936, 0.3252250552177429, 0.3251720070838928, 0.3253156244754791, 0.

{'DIS_COSTS': [0.3252421021461487, 0.32517629861831665, 0.32520410418510437, 0.32520902156829834, 0.32521775364875793, 0.3252309560775757, 0.32519781589508057, 0.3252214789390564, 0.3252304792404175, 0.3252553939819336, 0.32523566484451294, 0.32523924112319946, 0.3252105116844177, 0.32519590854644775, 0.32523518800735474, 0.3252221345901489, 0.32520586252212524, 0.3252124488353729, 0.3252161145210266, 0.3252643048763275, 0.32522889971733093, 0.32525092363357544, 0.32516974210739136, 0.32526132464408875, 0.32520920038223267, 0.3252355456352234, 0.3252120614051819, 0.3252214789390564, 0.32522061467170715, 0.32521963119506836, 0.3252031207084656, 0.3252352774143219, 0.3252440094947815, 0.3251942992210388, 0.32525724172592163, 0.3252253830432892, 0.3252583146095276, 0.32516956329345703, 0.32521867752075195, 0.32521840929985046, 0.32521915435791016, 0.32520025968551636, 0.3252279758453369, 0.32522404193878174, 0.32526201009750366, 0.3252737820148468, 0.3252585232257843, 0.3252064287662506, 