# Training Dialogue Encoder

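## Data Validation

Sanity-check the augmented data before training. `validate` pairs the first `n_chunks` JSON chunks of two directories by sorted filename. It then checks three things:

- Both directories yield the same number of chunks.
- Paired chunks hold the same number of dialogues.
- No two paired dialogues are identical.

The check is run for every pair among the positive (`insert`), negative (`replace`) and original dialogues.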

In [14]:
import os
import json


def validate(path_a, path_b, n_chunks=80):
    """Check that two directories of JSON dialogue chunks are aligned but differ.

    The first ``n_chunks`` ``*.json`` files of each directory are sorted by
    filename and paired positionally. The dialogues inside each pair of files
    are also paired by position.

    Each chunk file is assumed to hold a list of dialogues. Each dialogue is
    assumed to be a list of ``{"speaker": <0 or 1>, "utterance": <str>}`` items.
    This schema is inferred from ``parse`` below; TODO confirm against the
    augmentation scripts.

    Args:
        path_a: directory containing the first set of chunks.
        path_b: directory containing the second set of chunks.
        n_chunks: maximum number of chunks taken from each directory.

    Raises:
        AssertionError: if the directories yield different numbers of chunks,
            if paired chunks hold different numbers of dialogues, or if any
            paired dialogues render to identical text. The error is raised
            explicitly rather than via ``assert`` so the checks still run under
            ``python -O``.
    """

    def list_chunks(path):
        # Sorting makes the positional pairing between directories deterministic.
        return sorted(name for name in os.listdir(path) if name.endswith('.json'))[:n_chunks]

    def load(path):
        # ``with`` closes the handle right away instead of leaking it until GC.
        with open(path, 'r', encoding='utf-8') as fh:
            return json.load(fh)

    def parse(dia):
        # Render a dialogue as one "[A] ..." / "[B] ..." line per utterance;
        # the speaker value is used as an index into "AB".
        speaker_alias = "AB"
        return '\n'.join(f'[{speaker_alias[item["speaker"]]}] {item["utterance"]}' for item in dia)

    chunks_a = list_chunks(path_a)
    chunks_b = list_chunks(path_b)
    if len(chunks_a) != len(chunks_b):
        raise AssertionError(f'chunk numbers must match ({len(chunks_a)} vs {len(chunks_b)})')

    for chunk_a, chunk_b in zip(chunks_a, chunks_b):
        a = load(os.path.join(path_a, chunk_a))
        b = load(os.path.join(path_b, chunk_b))
        if len(a) != len(b):
            raise AssertionError(
                f'dialogue sizes must match ({chunk_a}: {len(a)} vs {chunk_b}: {len(b)})'
            )
        for i, (dia_a, dia_b) in enumerate(zip(a, b)):
            if parse(dia_a) == parse(dia_b):
                # Report the offending location so the failing dialogue can be inspected.
                raise AssertionError(
                    f'dialogues must not match ({chunk_a} vs {chunk_b}, dialogue #{i})'
                )

In [15]:
import os

# Root of the dialogue-augmentation checkout. To run on another machine, set the
# DIALOGUE_AUGMENTATION_ROOT environment variable instead of editing this cell.
# The default is the path used when this notebook was written.
AUGMENTATION_ROOT = os.environ.get(
    'DIALOGUE_AUGMENTATION_ROOT', '/home/alekseev_ilya/dialogue-augmentation'
)

path_original = f'{AUGMENTATION_ROOT}/nup/dialogues/train'  # original training dialogues
path_positive = f'{AUGMENTATION_ROOT}/augmented/insert'     # 'insert' augmentation (positives)
path_negative = f'{AUGMENTATION_ROOT}/augmented/replace'    # 'replace' augmentation (negatives)

In [16]:
# Positive ('insert') vs negative ('replace') chunks: they must pair up one-to-one
# and never yield the same dialogue.
validate(path_a=path_positive, path_b=path_negative)

In [17]:
# Positive ('insert') vs original chunks: every positive must differ from the
# original dialogue at the same position.
validate(path_a=path_positive, path_b=path_original)

In [18]:
# Negative ('replace') vs original chunks.
# NOTE(review): this check currently fails with "AssertionError: dialogues must
# not match". At least one negative dialogue renders identically to the original
# dialogue at the same position. Presumably the 'replace' augmentation left it
# unchanged; investigate or filter such negatives before training on them.
validate(path_negative, path_original)

AssertionError: dialogues must not match

## Model

In [None]:
import os
import json


def _list_json_chunks(directory, limit):
    """Return at most ``limit`` ``*.json`` filenames from ``directory``, in sorted order."""
    json_names = [name for name in os.listdir(directory) if name.endswith('.json')]
    json_names.sort()
    return json_names[:limit]


# Number of chunk files to take from each directory.
n_chunks = 80
chunks_original = _list_json_chunks(path_original, n_chunks)
chunks_positive = _list_json_chunks(path_positive, n_chunks)

In [None]:
import os
# These variables are set before `import torch` below. CUDA reads
# CUDA_VISIBLE_DEVICES when it initialises, so keep this ordering.
os.environ.update({
    'TOKENIZERS_PARALLELISM': 'false',  # presumably disables HF tokenizers' internal parallelism
    'CUDA_VISIBLE_DEVICES': '0',        # expose only GPU 0 to this process
})

import torch

# Trade float32 matmul precision for speed: 'medium' allows reduced-precision
# internal computation when a fast kernel for it is available.
torch.set_float32_matmul_precision('medium')

In [None]:
from nup.models.dialogue import SimpleDialogueEncoder


# Pretrained backbone checkpoint passed to SimpleDialogueEncoder.
# Presumably this is a Hugging Face model id; confirm against nup.models.dialogue.
encoder_name = 'xlnet-base-cased'
model = SimpleDialogueEncoder(encoder_name)

In [None]:
# Disabled alternative model: a hierarchical HSSADM dialogue model on a
# 'microsoft/mpnet-base' backbone, moved to GPU and frozen via requires_grad_(False).
# Kept for reference only; nothing below runs.
# NOTE(review): `casual_utterance_attention` looks like a typo for "causal", but it
# must match HSSAConfig's actual parameter name, so verify before renaming.
# from nup.models.dialogue import HSSAConfig, HSSADM
# config = HSSAConfig(
#     max_ut_embeddings=28,
#     casual_utterance_attention=False,
# )
# dialogue_model = HSSADM(
#     'microsoft/mpnet-base',
#     config,
#     pool_utterance_level=False,
#     pool_dialogue_level=True
# ).cuda()
# dialogue_model.requires_grad_(False)