# Explore network (MLP)

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

## Imports

In [None]:
import os

import numpy as np
import pandas as pd
import torch

from src.data.dataset import MovieDataset
from src.utils.const import DATA_DIR, SEED

### Useful paths to data

In [None]:
# ROOT_DIR is the parent of the current working directory.
# NOTE(review): this only points at the project root when the notebook is
# started from a direct sub-directory of it — confirm how the kernel is launched.
ROOT_DIR = os.path.join(os.getcwd(), '..')
# Processed data is looked up under <ROOT_DIR>/<DATA_DIR>/processed.
PROCESSED_DIR = os.path.join(ROOT_DIR, DATA_DIR, 'processed')

### Repeatability

In [None]:
# Seed the global PyTorch and NumPy random generators with the shared SEED
# constant so random draws repeat from one run of the notebook to the next.
# NOTE(review): Python's built-in `random` module is not seeded here —
# confirm nothing downstream relies on it.
torch.manual_seed(SEED)
np.random.seed(SEED)
# Ask PyTorch to use deterministic implementations; operations that only have
# a nondeterministic implementation raise an error instead of running.
# NOTE(review): on CUDA this may additionally require the
# CUBLAS_WORKSPACE_CONFIG environment variable — confirm if run on a GPU.
torch.use_deterministic_algorithms(True)
# Turn off cuDNN's benchmark-based algorithm selection, which can pick a
# different convolution algorithm from run to run.
torch.backends.cudnn.benchmark = False

## Import final dataset

In [None]:
# Load <PROCESSED_DIR>/final.parquet into a pandas DataFrame.
df = pd.read_parquet(os.path.join(PROCESSED_DIR, 'final.parquet'))

## Work with Dataset and DataLoader

In [None]:
from matplotlib import pyplot as plt


def plot_scatters(data):
    """Scatter-plot the first two columns of each split held by ``data``.

    ``data`` must expose ``X_train``, ``X_val`` and ``X_test`` attributes that
    support 2-D indexing (``[:, 0]`` and ``[:, 1]``). The splits are drawn on
    the current matplotlib axes in this order and colour:
    ``X_train`` in red, ``X_val`` in green, ``X_test`` in blue.
    """
    # Read all three splits before drawing anything, so a missing attribute
    # fails before the figure is touched.
    train_split, test_split, val_split = data.X_train, data.X_test, data.X_val

    # Draw order matters for overlap: later splits are painted on top.
    coloured_splits = (
        (train_split, 'r'),
        (val_split, 'g'),
        (test_split, 'b'),
    )
    for features, colour in coloured_splits:
        plt.scatter(features[:, 0], features[:, 1], c=colour)

In [None]:
# Baseline: build the dataset with MovieDataset's default options and plot
# the first two feature columns of its train/val/test splits.
dataset = MovieDataset(df)
plot_scatters(dataset)

### MinMaxScaler and StandardScaler

In [None]:
# Build the dataset with scale_method='min-max' and plot its splits.
# NOTE(review): presumably this applies min-max feature scaling inside
# MovieDataset — confirm against its implementation.
dataset = MovieDataset(df, scale_method='min-max')
plot_scatters(dataset)

In [None]:
# Build the dataset with scale_method='standardization' and plot its splits.
# NOTE(review): presumably this standardizes features (zero mean, unit
# variance) inside MovieDataset — confirm against its implementation.
dataset = MovieDataset(df, scale_method='standardization')
plot_scatters(dataset)

### Normalization

In [None]:
# Build the dataset with norm='l1' and plot its splits.
# NOTE(review): presumably each sample is rescaled to unit L1 norm inside
# MovieDataset — confirm against its implementation.
dataset = MovieDataset(df, norm='l1')
plot_scatters(dataset)

In [None]:
# Build the dataset with norm='l2' and plot its splits.
# NOTE(review): presumably each sample is rescaled to unit L2 norm inside
# MovieDataset — confirm against its implementation.
dataset = MovieDataset(df, norm='l2')
plot_scatters(dataset)

In [None]:
# Build the dataset with norm='max' and plot its splits.
# NOTE(review): presumably each sample is divided by its largest absolute
# value inside MovieDataset — confirm against its implementation.
dataset = MovieDataset(df, norm='max')
plot_scatters(dataset)

### LinearDiscriminantAnalysis

In [None]:
# Rebuild the dataset with default options (no scale_method / norm argument)
# so the LDA cells below do not reuse the last variant built above.
dataset = MovieDataset(df)

In [None]:
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

# Pull the train/test features and their labels out of the dataset.
X_train, X_test = dataset.X_train, dataset.X_test
y_train, y_test = dataset.y_train, dataset.y_test

# Fit LDA on the training split only; fit() returns the estimator itself,
# so construction and fitting are chained.
lda = LinearDiscriminantAnalysis().fit(X_train, y_train)

# Project both splits with the transform learned from the training data.
X_train_t, X_test_t = lda.transform(X_train), lda.transform(X_test)

In [None]:
# Plot the first two LDA components of the training split, coloured by label.
# NOTE(review): indexing column 1 needs at least two discriminant components;
# LDA yields at most (n_classes - 1), so this presumably assumes >= 3 classes
# in y_train — confirm, otherwise this raises IndexError.
plt.scatter(X_train_t[:, 0], X_train_t[:, 1], c=y_train)
plt.show()
# Same projection for the test split, shown as a separate figure.
plt.scatter(X_test_t[:, 0], X_test_t[:, 1], c=y_test)
plt.show()