In [15]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pickle

with open('/Users/Phil/Documents/Frosh/CS224N/note2vec/data/jsb-chorales-16th.pkl', 'rb') as f:
    u = pickle._Unpickler(f)
    u.encoding = 'latin1'
    p = u.load()
    
print(data.keys())

dict_keys(['test', 'train', 'valid'])
33


In [9]:
vocab = set()
for dataset in data:
    for chorale in data[dataset]:
        for chord in chorale:
            vocab = vocab.union(chord)
w2i = {w: i for i, w in enumerate(vocab)}
i2w = {i: w for i, w in enumerate(vocab)}
notes = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']

def num2note(num):
    note_name = notes[num % 12]
    octave = str((num-12)//12)
    return note_name + octave

vocab_size = len(vocab)

print('Found', vocab_size, 'unique notes')

Found 52 unique notes


In [3]:
import numpy.random

counts = {w2i[note]:1 for note in vocab}
for dataset in data:
    for chorale in data[dataset]:
        for chord in chorale:
            for note in chord:
                counts[w2i[note]] += 1

total = sum([counts[k] for k in counts])
vals = [counts[k]/total for k in counts]
            
def getNegativeSample():
    return numpy.random.choice(list(counts.keys()), p=vals)

In [39]:
from torch.utils.data import TensorDataset, DataLoader
batch_size = 32

def create_skipgram_dataset(chorales):
    data = []
    for ch in chorales:
        for c in ch:
            for n in c:
                data += [(w2i[n], w2i[n1], 1) for n1 in c if n1 != n]
                # negative sample
                data += [(w2i[n], getNegativeSample(), 0) for _ in range(3)]
    dataset = TensorDataset(torch.tensor(skipgram_data, dtype=torch.long))
    loader = DataLoader(sg_dataset, batch_size=batch_size, shuffle=False, num_workers=1)
    return loader

sg_loader = create_skipgram_dataset(data['train'])

In [5]:
class SkipGram(nn.Module):
    def __init__(self, vocab_size, embed_size):
        super(SkipGram, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embed_size)
    
    def forward(self, focus, context):
        embed_focus = self.embeddings(focus)
        embed_focus = embed_focus.view((embed_focus.shape[0], 1, embed_focus.shape[1]))
        embed_context = self.embeddings(context)
        embed_context = embed_context.view((embed_context.shape[0], embed_context.shape[1], 1))
        scores = torch.bmm(embed_focus, embed_context)
        log_probs = F.logsigmoid(scores)
        return log_probs

In [27]:
embed_size = 8
learning_rate = 0.01
n_epoch = 20

def train_skipgram():
    losses = []
    loss_fn = nn.L1Loss()
    model = SkipGram(vocab_size, embed_size)
    print(model)
    # model.load_state_dict(torch.load('model'))
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)
    
    for epoch in range(n_epoch):
        total_loss = 0.0
        for i, sample_batched in enumerate(sg_loader):
            sample_batched = sample_batched[0]
            in_w_var = Variable(sample_batched[:,0])
            ctx_w_var = Variable(sample_batched[:,1])
            # print(in_w_var.shape)
            model.zero_grad()
            log_probs = model(in_w_var, ctx_w_var)
            loss = loss_fn(log_probs, Variable(sample_batched[:,2].float()))
            
            loss.backward()
            optimizer.step()
            
            total_loss += loss.data
    
        losses.append(total_loss.item())
        print('Epoch:', epoch, 'Loss:', total_loss.item())
        torch.save(model.state_dict(), 'model8')
    return model, losses

sg_model, sg_losses = train_skipgram()

SkipGram(
  (embeddings): Embedding(52, 8)
)
Epoch: 0 Loss: 9377.3212890625
Epoch: 1 Loss: 5754.92919921875
Epoch: 2 Loss: 5353.03271484375
Epoch: 3 Loss: 5214.1640625
Epoch: 4 Loss: 5136.5703125
Epoch: 5 Loss: 5083.423828125
Epoch: 6 Loss: 5045.63134765625
Epoch: 7 Loss: 5019.3740234375
Epoch: 8 Loss: 5001.21728515625
Epoch: 9 Loss: 4988.39599609375
Epoch: 10 Loss: 4979.05029296875
Epoch: 11 Loss: 4972.03515625
Epoch: 12 Loss: 4966.55859375
Epoch: 13 Loss: 4962.13427734375
Epoch: 14 Loss: 4958.53515625
Epoch: 15 Loss: 4955.49658203125
Epoch: 16 Loss: 4952.90087890625
Epoch: 17 Loss: 4950.65771484375
Epoch: 18 Loss: 4948.7119140625
Epoch: 19 Loss: 4946.97607421875
Epoch: 20 Loss: 4945.44140625
Epoch: 21 Loss: 4944.07275390625
Epoch: 22 Loss: 4942.82177734375
Epoch: 23 Loss: 4941.6923828125
Epoch: 24 Loss: 4940.673828125
Epoch: 25 Loss: 4939.73779296875
Epoch: 26 Loss: 4938.87744140625
Epoch: 27 Loss: 4938.07275390625
Epoch: 28 Loss: 4937.34375
Epoch: 29 Loss: 4936.66552734375


In [38]:
import numpy as np
# from sklearn.decomposition import PCA
model = SkipGram(vocab_size, embed_size)
model.load_state_dict(torch.load('model{}'.format(embed_size)))

embeddings = np.array(model.embeddings.weight.data)
# pca = PCA(n_components=2)
# pca.fit(embeddings)

with open('embeddings{}.txt'.format(embed_size), 'w') as f:
    for i in range(len(embeddings)):
        f.write("{}\t{}\t{}\t{}\n".format(embeddings[i,0], embeddings[i,1], embeddings[i,2], embeddings[i,3]))

with open('meta.tsv', 'w') as f:
    for i in range(len(i2w)):
        f.write('{}\n'.format(num2note(i2w[i])))
        