<a href="https://colab.research.google.com/github/vinnik-dmitry07/Chatbot/blob/main/train_chatbot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!nvidia-smi
!pip install --quiet parlai

In [None]:
from pathlib import Path

# Everything persistent lives under the mounted Google Drive folder.
GDRIVE_ROOT = Path('/content/drive/MyDrive/')
# Training checkpoints are written here.
SAVE_DIR = GDRIVE_ROOT.joinpath('chatbot_model')
# The raw export is read from here and the generated data_*.jsonl files are written here.
DATA_DIR = GDRIVE_ROOT.joinpath('chatbot_data')
# Interval (in minutes) passed to the trainer as validation_every_n_secs.
# Author's note: roughly +1.23 GB of Drive usage per interval — presumably the
# size of a saved checkpoint; TODO confirm.
SAVE_EVERY_N_MINUTES = 40

In [None]:
from datetime import timedelta

# Maximum pause between two consecutive messages of the same dialogue; a longer
# pause starts a new episode. Tune this to change how the chat is split up.
EPISODE_DT = timedelta(minutes=3)
# Fractions of the (chronologically ordered) messages used for each split.
TRAIN_PART, TEST_PART, VALID_PART = 0.996, 0.002, 0.002

# The fractions must cover the whole data set. Compare with a tolerance: an
# exact float `== 1` check can fail for perfectly valid fractions because of
# rounding (e.g. 0.7 + 0.2 + 0.1 != 1.0).
assert abs(TRAIN_PART + TEST_PART + VALID_PART - 1) < 1e-9, 'split fractions must sum to 1'

In [None]:
from google.colab import drive

# Mount Google Drive at the directory that contains GDRIVE_ROOT so that
# DATA_DIR and SAVE_DIR become reachable from this runtime.
drive_mount_point = str(GDRIVE_ROOT.parent)
drive.mount(drive_mount_point)

Mounted at /content/drive


In [None]:
import json

# Read the exported chat history and keep only its list of message records.
export_path = DATA_DIR / 'result.json'
with export_path.open('r', encoding='utf8') as export_file:
    raw_messages = json.load(export_file)['messages']

In [None]:
from datetime import datetime

# Keep only short, plain-text messages that have a sender, and parse their
# timestamps so later cells can compute time gaps between them.
filtered_messages = []
for msg in raw_messages:
    # Drop records without sender information and records carrying a file
    # attachment (those have a 'mime_type' key).
    if 'from' not in msg or 'from_id' not in msg or 'mime_type' in msg:
        continue

    # Drop empty texts, texts that are not a plain string, and texts of
    # 50 characters or more.
    text = msg['text']
    if not text or not isinstance(text, str) or len(text) >= 50:
        continue

    # Work on a shallow copy so the raw records stay untouched.
    parsed = dict(msg)
    parsed['date'] = datetime.strptime(msg['date'], '%Y-%m-%dT%H:%M:%S')
    filtered_messages.append(parsed)

In [None]:
import re

# Merge runs of consecutive messages from the same sender into a single
# message, after reducing each text to ASCII letters, digits and spaces.
joined_messages = []
# Last message that survived cleaning, i.e. the one most recently merged into
# or appended to joined_messages. Comparing against this (rather than against
# filtered_messages[i - 1]) keeps messages whose cleaned text is empty from
# influencing the merge decision, so text is never appended to another
# sender's message.
previous_kept = None
for message in filtered_messages:
    alphanum_text = re.sub(r'[^A-Za-z0-9 ]+', '', message['text']).strip()
    if not alphanum_text:
        continue

    if (
            previous_kept is not None and
            previous_kept['from_id'] == message['from_id'] and
            # Gap is current minus previous: the reverse order is never
            # positive for chronological data and made this check always true.
            message['date'] - previous_kept['date'] <= EPISODE_DT
    ):
        # NOTE(review): the merged message keeps the date of its first part,
        # which is what save_jsonl later uses to measure episode gaps.
        joined_messages[-1]['text'] += ' ' + alphanum_text
    else:
        new_message = message.copy()
        new_message['text'] = alphanum_text
        joined_messages.append(new_message)
    previous_kept = message

In [None]:
def partition(alist, indices):
    """Cut ``alist`` into consecutive slices at the given positions.

    Args:
        alist: Any sliceable sequence.
        indices: List of cut positions; each one is the start of a new slice.

    Returns:
        A list with ``len(indices) + 1`` slices: from 0 to the first index,
        between each pair of neighbouring indices, and from the last index to
        the end of ``alist``.
    """
    starts = [0] + indices
    stops = indices + [None]  # None makes the final slice run to the end
    pieces = []
    for start, stop in zip(starts, stops):
        pieces.append(alist[start:stop])
    return pieces

In [None]:
def save_jsonl(messages, suffix, human_readable=False):
    """Group messages into episodes and write them to ``data_<suffix>.jsonl``.

    A new episode starts before every message that arrives more than
    EPISODE_DT after the message preceding it. Each episode becomes one JSON
    line of the form ``{"dialog": [[{"id": ..., "text": ...}, ...]]}``, where
    ``id`` alternates 0, 1, 0, ... by position inside the episode.

    Args:
        messages: Sequence of message dicts with 'date' and 'text' entries.
        suffix: Split name; used in the output file name and the printed summary.
        human_readable: If true, open the file as UTF-8 and write non-ASCII
            characters verbatim instead of as ``\\uXXXX`` escapes.
    """
    # Positions at which a new episode begins.
    boundaries = []
    for pos in range(1, len(messages)):
        if messages[pos]['date'] - messages[pos - 1]['date'] > EPISODE_DT:
            boundaries.append(pos)
    episodes = partition(messages, boundaries)
    print(f'{suffix} episodes: {len(episodes)}, messages: {len(messages)}')

    if human_readable:
        open_options = {'encoding': 'utf8'}
        dump_options = {'ensure_ascii': False}
    else:
        open_options = {}
        dump_options = {}

    with open(DATA_DIR / f'data_{suffix}.jsonl', 'w', **open_options) as outfile:
        for turns in episodes:
            dialog = []
            for turn_no, msg in enumerate(turns):
                dialog.append({'id': turn_no % 2, 'text': msg['text']})
            json.dump({'dialog': [dialog]}, outfile, **dump_options)
            outfile.write('\n')

In [None]:
import numpy as np

# Cut the ordered message list into train / test / valid parts at the
# cumulative TRAIN_PART and TRAIN_PART + TEST_PART fractions of its length.
n_messages = len(joined_messages)
train_end = int(TRAIN_PART * n_messages)
test_end = int((TRAIN_PART + TEST_PART) * n_messages)
train, test, valid = np.split(joined_messages, [train_end, test_end])

# Write one data_<suffix>.jsonl file per split, in train, test, valid order.
for split_name, split_messages in (('train', train), ('test', test), ('valid', valid)):
    save_jsonl(split_messages, suffix=split_name)

train episodes: 424, messages: 754917
test episodes: 1, messages: 1516
valid episodes: 1, messages: 1516


In [None]:
import os

os.environ['SAVE_DIR'] = str(SAVE_DIR)
!rm --recursive --force $SAVE_DIR
!mkdir --parents $SAVE_DIR


from parlai.scripts.train_model import TrainModel

# Fine-tune a ParlAI transformer generator on the data_*.jsonl files.
# NOTE(review): the option descriptions below are inferred from the option
# names and the author's inline notes — confirm against the ParlAI docs for
# the installed version.
TrainModel.main(
    # Data: read through the 'jsonfile' task from the DATA_DIR / 'data' prefix.
    # Presumably jsonfile_datatype_extension=True makes ParlAI append
    # '_<datatype>.jsonl', matching the data_<suffix>.jsonl names written by
    # save_jsonl — TODO confirm.
    task='jsonfile',
    jsonfile_datapath=str(DATA_DIR / 'data'),
    jsonfile_datatype_extension=True,

    # Model type and the path prefix under SAVE_DIR used for the model files.
    model='transformer/generator',
    model_file=str(SAVE_DIR / 'model'),
    
    # Start from the pretrained tutorial model from the ParlAI zoo.
    init_model='zoo:tutorial_transformer_generator/model',

    # Architecture and dictionary settings — presumably these have to match
    # the init_model checkpoint; verify before changing any of them.
    n_heads=16, n_layers=8, n_positions=512, text_truncate=512,
    label_truncate=128, ffn_size=2048, embedding_size=512,
    activation='gelu', variant='xlm',
    dict_lower=True, dict_tokenizer='bpe',
    dict_file='zoo:tutorial_transformer_generator/model.dict',
    learn_positional_embeddings=True,
    
    # Optimisation settings; validation runs every SAVE_EVERY_N_MINUTES minutes.
    lr=1e-5, optimizer='adam',
    warmup_updates=5000,
    validation_metric='ppl',
    validation_every_n_secs=SAVE_EVERY_N_MINUTES * 60,  # running eval: valid
    # save_every_n_secs=60,  # saving model checkpoint

    # Batch size and half-precision settings.
    batchsize=12, fp16=True, fp16_impl='mem_efficient',
    
    skip_generation=True,
    
    dynamic_batching='full',

    label_turns='both',  # https://parl.ai/docs/core/teachers.html#parlai.core.teachers.ConversationTeacher
)