In [1]:
import numpy as np
import pandas as pd

from torch.utils.data import DataLoader
import math
from sentence_transformers import SentenceTransformer, losses, InputExample
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from datasets import DatasetDict
import logging

from datetime import datetime

from os import getenv
import os
import sys
# Append the directory two levels above sys.path[0] to the import path so the
# project-local `src` package imported below can be resolved.
# NOTE(review): presumably that directory is the project root -- confirm.
sys.path.append(os.path.dirname(os.path.dirname(sys.path[0])))

from src.paths import get_project_root, abs_path
import wandb
from dotenv import load_dotenv
# Read a .env file into the process environment; later cells pick up their
# SBERT_* settings through getenv and wandb.login() relies on WANDB_API_KEY.
load_dotenv()

True

In [2]:
%env "WANDB_NOTEBOOK_NAME" "train_sbert"

model_name = getenv("SBERT_MODEL", 'all-mpnet-base-v2')
train_batch_size = int(getenv("SBERT_BATCH_SIZE", 12))
num_epochs = int(getenv("SBERT_N_EPOCHS", 1))
evaluation_steps = int(getenv("SBERT_EVALUATION_STEPS", 1000))

model_save_path = abs_path("models", f'{model_name}-{datetime.now().strftime("%Y-%m-%d_%H-%M-%S")}')

env: "WANDB_NOTEBOOK_NAME"="train_sbert"


In [3]:
wandb.login() # relies on WANDB_API_KEY env var

# Start the training run and record every hyperparameter chosen in the settings cell,
# including evaluation_steps, which was previously missing from the logged config.
run = wandb.init(
    project="ea-forum-analysis", job_type="training", dir=get_project_root(),
    config={
        "model_name": model_name,
        "train_batch_size": train_batch_size,
        "num_epochs": num_epochs,
        "evaluation_steps": evaluation_steps
    }
)

[34m[1mwandb[0m: Currently logged in as: [33mvpetukhov[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [8]:
# Declare the "post_pairs" dataset artifact as an input of this run and store the
# version that the ":latest" alias resolved to in the run config.
art = run.use_artifact("post_pairs:latest")
run.config.update(dict(data_version=art.version))

# Download the artifact files and open the result as a DatasetDict.
dataset_dir = art.download()
data = DatasetDict.load_from_disk(dataset_dir)
data

[34m[1mwandb[0m: Downloading large artifact post_pairs:latest, 403.55MB. 7 files... 
[34m[1mwandb[0m:   7 of 7 files downloaded.  
Done. 0:0:0.0


DatasetDict({
    train: Dataset({
        features: ['src_text', 'dst_text', 'src_post_id', 'dst_post_id', 'sims'],
        num_rows: 174991
    })
    dev: Dataset({
        features: ['src_text', 'dst_text', 'src_post_id', 'dst_post_id', 'sims'],
        num_rows: 4220
    })
})

In [29]:
# Load the pretrained checkpoint. SentenceTransformer selects a GPU on its own when one
# is available, so the explicit .cuda() call (which raises on CPU-only machines) is dropped.
model = SentenceTransformer(model_name)
# Have W&B track the model (log="all") every 1000 steps; the trailing `;` suppresses
# the empty-list repr that the call would otherwise print as cell output.
wandb.watch(model, log="all", log_freq=1000);

[]

In [38]:
# One InputExample per training pair: the two post texts and their `sims` value as a float label.
train_split = data['train']
train_samples = [
    InputExample(texts=(src_text, dst_text), label=float(sim))
    for src_text, dst_text, sim in zip(
        train_split['src_text'], train_split['dst_text'], train_split['sims']
    )
]
# Shuffled mini-batches for training, optimised with a cosine-similarity loss on the model.
train_loader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
train_loss = losses.CosineSimilarityLoss(model=model)

In [39]:
# Evaluator built from the dev split: both text columns plus the stored `sims` values
# that the model's embedding similarities are compared against.
dev_split = data['dev']
evaluator = EmbeddingSimilarityEvaluator(
    sentences1=dev_split['src_text'],
    sentences2=dev_split['dst_text'],
    scores=dev_split['sims'],
)

In [40]:
# Warm-up length: 10% of all training steps (batches per epoch x epochs), rounded up.
total_train_steps = len(train_loader) * num_epochs
warmup_steps = math.ceil(total_train_steps * 0.1)
warmup_steps

1459

In [44]:
# Lower the root logger's threshold to DEBUG so that library log records are not
# filtered out by level before they reach any configured handler.
logger = logging.getLogger()
logger.setLevel("DEBUG")

In [53]:
# Evaluate the model ahead of the training cell and log the result as a
# reference point (epoch -1, step 0) next to the scores logged during training.
score = evaluator(model)
baseline_metrics = {'score': score, 'epoch': -1, 'steps': 0}
wandb.log(baseline_metrics)
score

0.5504928612334004

In [43]:
def log_eval_to_wandb(score, epoch, steps):
    """Forward one evaluation result reported by `model.fit` to the active W&B run."""
    wandb.log({'score': score, 'epoch': epoch, 'steps': steps})


# Fine-tune the model on the training pairs; the evaluator is run every
# `evaluation_steps` steps and each result is passed to the callback above.
model.fit(
    train_objectives=[(train_loader, train_loss)],
    evaluator=evaluator,
    epochs=num_epochs,
    evaluation_steps=evaluation_steps,
    warmup_steps=warmup_steps,
    output_path=model_save_path,
    callback=log_eval_to_wandb,
)

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

Iteration:   0%|          | 0/14583 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Package the saved model directory as a W&B model artifact.
# The data version is read back from the run config (stored there when the dataset
# artifact was loaded) and the model artifact gets its own name. The previous code
# rebound `art` while reading `art.version`, so on a second execution `art` no longer
# referred to the dataset artifact and the recorded data version was wrong.
model_art = wandb.Artifact(
    "sbert", type="model",
    metadata={'model_name': model_name, 'data_version': run.config['data_version']}
)
model_art.add_dir(model_save_path)
run.log_artifact(model_art, aliases=[model_name])

In [None]:
# Mark the training run started by wandb.init above as finished.
wandb.finish()

## Log the Colab model

In [3]:
# Open a separate run (no hyperparameter config) that is only used to upload
# a model directory that already exists on disk.
wandb.login()  # relies on WANDB_API_KEY env var
init_kwargs = dict(project="ea-forum-analysis", job_type="training", dir=get_project_root())
run = wandb.init(**init_kwargs)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mvpetukhov[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [10]:
# Directory under "models" that holds the model trained on Colab.
colab_model_dir = abs_path("models", "n8-all-mpnet-base-v2-2022-11-03_09-01-58")

# Register that directory as a new version of the "sbert" model artifact.
art = wandb.Artifact("sbert", type="model", description="Model trained on Colab")
art.add_dir(colab_model_dir)
run.log_artifact(art, aliases=["n8-all-mpnet-base-v2"])

[34m[1mwandb[0m: Adding directory to artifact (/home/vpetukhov/other/Consulting/SEADS/EAForumExperiments/models/n8-all-mpnet-base-v2-2022-11-03_09-01-58)... Done. 1.6s


<wandb.sdk.wandb_artifacts.Artifact at 0x7f8daf9a0eb0>

In [11]:
# Mark the upload run started in this section as finished.
wandb.finish()