In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset, DataLoader
import pickle, json
import numpy as np

import pandas as pd

In [2]:
model_type = "conversation" # abc or conversation

In [3]:
# with open(f'../data/raw/{model_type}_points.json', 'r') as f:
#     points = json.load(f)

# new_points = []
# labels = set()

# for point in points:
#     label = point['label']
#     data = point['data']

#     labels.add(label)

#     if label not in ("cool", "bye"):
#         new_points.append({
#             'label': label,
#             'data': data
#         })

# print(labels)

# with open(f'../data/raw/{model_type}_points.json', 'w') as f:
#     json.dump(new_points, f)

In [4]:
label2id = {}
id2label = {}

def load_data():
    global NUM_LABELS

    with open(f'../data/raw/{model_type}_points.json', 'r') as f:
        points = json.load(f)

    for item in points:
        label = item["label"]
        if label not in label2id:
            label2id[label] = len(label2id)
            id2label[len(id2label)] = label

    return pd.DataFrame(points).reset_index(drop=True)

data = load_data()
data

Unnamed: 0,data,label
0,"[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, ...",yes
1,"[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, ...",yes
2,"[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, ...",yes
3,"[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, ...",yes
4,"[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, ...",yes
...,...,...
2395,"[[0.061731815338134766, 1.0, 1.0], [-0.4242817...",easter_egg
2396,"[[0.06483232975006104, 1.0, 1.0], [-0.42371267...",easter_egg
2397,"[[0.061743855476379395, 1.0, 1.0], [-0.4254379...",easter_egg
2398,"[[0.06654179096221924, 1.0, 1.0], [-0.42671757...",easter_egg


In [5]:
# Save id2label dict
with open(f"../model/{model_type}_id2label.pkl", "wb") as f:
    pickle.dump(id2label, f)

id2label

{0: 'yes',
 1: 'no',
 2: 'hello',
 3: 'bye',
 4: 'how',
 5: 'old',
 6: 'why',
 7: 'you',
 8: 'me',
 9: 'explore',
 10: 'deaf',
 11: 'easter_egg'}

In [6]:
 # Convert labels to numbers
data['label'] = data['label'].apply(lambda x: label2id[x])
data['label'].value_counts()

0     200
1     200
2     200
3     200
4     200
5     200
6     200
7     200
8     200
9     200
10    200
11    200
Name: label, dtype: int64

In [7]:
def train_test_split(df, test_size=0.4):
    test_size = int(len(df) * test_size)
    train_size = len(df) - test_size

    # Shuffle data
    df = df.sample(frac=1).reset_index(drop=True)

    # Split data
    train = df[:train_size]
    test = df[train_size:]

    return train, test

def make_smaller(df, frac=1):
    return df.sample(frac=frac)

train, test = train_test_split(data)
train, test = make_smaller(train), make_smaller(test)

with open(f"../data/clean/{model_type}_train.pkl", "wb") as f:
    pickle.dump(train.to_dict('records'), f)

with open(f"../data/clean/{model_type}_test.pkl", "wb") as f:
    pickle.dump(test.to_dict('records'), f)

In [8]:
len(train), len(test)

(1440, 960)

In [9]:
train

Unnamed: 0,data,label
573,"[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, ...",2
55,"[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, ...",10
782,"[[0.11906874179840088, 1.0, 1.0], [-0.41815882...",11
39,"[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, ...",10
99,"[[-1.0, 1.0, -1.0], [-0.22503364086151123, 0.9...",9
...,...,...
350,"[[-1.0, -0.044448673725128174, 1.0], [-0.97291...",4
216,"[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, ...",6
983,"[[-1.0, 0.5546283721923828, 1.0], [-0.86920803...",4
73,"[[-1.0, 1.0, -1.0], [-0.3329070210456848, 0.99...",9


In [10]:
test

Unnamed: 0,data,label
2376,"[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, ...",0
2313,"[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, ...",0
1798,"[[-1.0, 0.6329032182693481, 1.0], [-0.84634947...",4
1503,"[[-1.0, 1.0, -1.0], [-0.26882362365722656, 0.9...",9
1532,"[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, ...",8
...,...,...
1906,"[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, ...",1
2021,"[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, ...",2
1730,"[[-1.0, 0.11205041408538818, 1.0], [-0.9938939...",4
2113,"[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, ...",6
