In [None]:
# Mount Google Drive into this Colab runtime; every data/output path used by
# the notebook lives under this mount point.
from google.colab import drive

drive.mount(mountpoint='/content/gdrive/')

Mounted at /content/gdrive/


In [None]:
import os
import sys

# Root of the mounted Google Drive; every other path hangs off it.
path2drive = '/content/gdrive/My Drive'
path2data = os.path.join(path2drive, 'AMAZON_FASHION.json')  # input dataset
output = os.path.join(path2drive, 'output')                  # all training artefacts

checkpoints_dir = os.path.join(output, 'checkpoints')  # holds checkpoint-<N>.pth files
losses_file = os.path.join(output, 'losses.txt')       # one loss value appended per line

# Make modules stored in the Drive root importable (presumably utils.py and
# train.py, which later cells import). Guarded so re-running this cell does
# not keep appending duplicate entries to sys.path.
if path2drive not in sys.path:
    sys.path.append(path2drive)

In [None]:
# Create output/ and output/checkpoints/ if they are missing. The parent is
# listed before the child because os.mkdir does not create missing parents,
# which also means this raises if the Drive folder itself is absent instead
# of silently creating the tree somewhere else.
for dir_path in (output, checkpoints_dir):
    if not os.path.exists(dir_path):
        os.mkdir(dir_path)

In [None]:
import re

# NOTE(review): star import kept because other cells rely on it; this cell uses
# Word2Vec, DataSet, DEVICE, LR, TARGET_LR, STEP_SIZE and EPOCHS from utils.
from utils import *
from train import train

import torch
import torch.optim as optim

# The training cell saves checkpoints as 'checkpoint-<N>.pth'.
_CHECKPOINT_RE = re.compile(r'checkpoint-(\d+)\.pth')


def _latest_checkpoint_num(directory):
    """Return the highest N of any 'checkpoint-N.pth' in `directory`, or 0 if none.

    Entries that do not follow the naming scheme (stray files, sync artefacts)
    are skipped instead of crashing the integer parse.
    """
    matches = (_CHECKPOINT_RE.fullmatch(name) for name in os.listdir(directory))
    return max((int(m.group(1)) for m in matches if m), default=0)


# Resume from the newest checkpoint if one exists; 0 means "start fresh".
checkpoint_num = _latest_checkpoint_num(checkpoints_dir)

if not checkpoint_num:
    initial_epoch = 1
    initial_sentence_num = 0

    model = Word2Vec().to(DEVICE)
    ds = DataSet(path2data, initial_seed=initial_epoch)

    optimizer = optim.Adam(model.parameters(), LR)
    # StepLR multiplies the LR by gamma every STEP_SIZE scheduler steps; this
    # gamma makes the LR decay geometrically from LR to TARGET_LR over
    # EPOCHS * len(ds) steps (assumes train() steps the scheduler once per
    # sentence — TODO confirm against train.py).
    scheduler = optim.lr_scheduler.StepLR(
        optimizer, STEP_SIZE,
        gamma=(TARGET_LR / LR) ** (STEP_SIZE / (EPOCHS * len(ds)))
    )
else:
    filename = f'checkpoint-{checkpoint_num}.pth'
    # Checkpoints pickle whole objects (model / optimizer / scheduler), not just
    # tensors, so weights_only=False is required: torch >= 2.6 defaults to
    # weights_only=True and refuses to unpickle them. Only acceptable because
    # these files are written by this notebook — never load untrusted ones.
    # NOTE(review): the weights_only argument requires torch >= 1.13.
    checkpoint_dict = torch.load(
        os.path.join(checkpoints_dir, filename),
        map_location=DEVICE,
        weights_only=False,
    )

    initial_epoch = checkpoint_dict['initial_epoch']
    initial_sentence_num = checkpoint_dict['initial_sentence_num']

    model = checkpoint_dict['model']
    # Same seeding convention as the fresh-start branch: seed = current epoch.
    ds = DataSet(path2data, initial_seed=initial_epoch)

    optimizer = checkpoint_dict['optimizer']
    scheduler = checkpoint_dict['scheduler']

In [None]:
model.train()

# train() yields (loss, epoch, sentence_num) tuples; each one triggers a loss
# log entry and a new checkpoint.
training_iter = train(model, ds, optimizer, scheduler, initial_epoch, initial_sentence_num)
for loss, epoch, sentence_num in training_iter:
    checkpoint_num += 1

    # Append the loss to the running log, one value per line.
    with open(losses_file, 'a') as loss_log:
        loss_log.write(f'{loss.item()}\n')

    filename = f'checkpoint-{checkpoint_num}.pth'
    final_path = os.path.join(checkpoints_dir, filename)
    # Save into a scratch file in `output` (outside checkpoints_dir, so a
    # leftover never looks like a checkpoint) and rename it into place only
    # after torch.save has finished. An interrupted save (e.g. the Colab
    # runtime disconnecting) then cannot leave a truncated newest
    # checkpoint-N.pth for the resume logic to pick up. os.replace is atomic
    # on POSIX filesystems — NOTE(review): confirm the Drive mount honours that.
    partial_path = os.path.join(output, filename + '.partial')
    torch.save({
        'model': model,
        'optimizer': optimizer,
        'scheduler': scheduler,
        'loss': loss,
        # Resume point: same epoch, starting at the sentence after sentence_num.
        'initial_epoch': epoch,
        'initial_sentence_num': sentence_num + 1
    }, partial_path)
    os.replace(partial_path, final_path)

model.eval()

Sentences:   0%|          | 0/524648 [00:00<?, ?it/s]

In [None]:
# Final artefacts: the trained model's state_dict, followed by a Vector built
# from the model's embedding and the dataset vocabulary.
model_path = os.path.join(output, 'model.pth')
torch.save(model.state_dict(), model_path)

vector = Vector(model.embedding, ds.vocabulary)
vector_path = os.path.join(output, 'vector.pth')
torch.save(vector, vector_path)