### 1. Doing necessary imports and setting up

In [None]:
# pip3 install stable-baselines3 hydra-core torchtoolbox
# pip3 install torch torchvision torchaudio

In [None]:
# %load_ext autoreload
# %autoreload 2

In [None]:
import sys
sys.path.append("..")

import logging
import os
import time
import datetime
import random
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
import csv
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
from os.path import join, dirname
from sklearn.model_selection import train_test_split
from imitation_learning.dataset.frame_dataset import Float115Dataset
from models.mlp import MLPModel
from utils.solver import Solver
from gfootball.env.wrappers import Simple115StateWrapper
from torch.utils.tensorboard import SummaryWriter
from imitation_learning.dataset.dict_dataset import Float115Dataset as DictDataset



# import hydra
# import wandb
# from omegaconf import DictConfig
# import torchtoolbox.transform as transforms
# import torchvision
# import pickle

In [None]:
# logging.basicConfig(level=logging.INFO)
# logger = logging.getLogger(__name__)

### 2.1 Loading Dataset from dict files

In [None]:
# Seed every random number generator the notebook touches (Python hashing,
# `random`, NumPy, PyTorch CPU and CUDA) so that runs are repeatable.
SEED = 42
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
# Benchmark mode lets cuDNN auto-tune and possibly pick a different algorithm
# on each run, which defeats the determinism requested above; keep it off.
torch.backends.cudnn.benchmark = False

In [None]:
# Read the header-less CSV and keep only its first column (label 0) as a
# pandas Series; the split cell below converts its entries to strings.
dataset = pd.read_csv('../data/frames.csv', header=None)[0]

In [None]:
# Shuffle the frame list with a fixed seed, then cut it at the 60% and 80%
# marks to obtain the train / validation / test partitions.
split_points = [int(.6 * len(dataset)), int(.8 * len(dataset))]
train, val, test = np.split(dataset.sample(frac=1, random_state=42), split_points)

# String arrays of the frame identifiers, one per partition.
train_frames, val_frames, test_frames = (
    np.array(part, dtype='str') for part in (train, val, test)
)

In [None]:
# Report the in-memory size (in MB, 1 MB = 1,000,000 bytes) of the full frame
# list and of each partition, then display the individual sizes.
dataset_mb = dataset.nbytes/1000000
train_mb = train.nbytes/1000000
val_mb = val.nbytes/1000000
test_mb = test.nbytes/1000000
print(f"Total memory occupied {dataset_mb + train_mb + val_mb + test_mb} MB")
dataset_mb, train_mb, val_mb, test_mb

In [None]:
# Wrap each partition's frame names in a Float115Dataset that reads from
# `dataset_path`, then report how many samples each split holds.
dataset_path = '/home/ssk/Study/GRP/dataset/npy_files'
train_dataset = Float115Dataset(train_frames, dataset_path)
val_dataset = Float115Dataset(val_frames, dataset_path)
test_dataset = Float115Dataset(test_frames, dataset_path)

print("Number of training samples:", len(train_dataset))
print("Number of validation samples:", len(val_dataset))
print("Number of test samples:", len(test_dataset))

In [None]:
# Drop the intermediate Series/arrays now that the Dataset objects exist.
del train, val, test, dataset, train_frames, val_frames, test_frames

In [None]:
train_dataset

In [None]:
# Batch the three splits. Only the training data is shuffled: evaluation
# metrics do not depend on sample order, and a fixed order keeps validation
# and test passes reproducible from run to run.
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True, num_workers=4)
val_loader = DataLoader(dataset=val_dataset, batch_size=128, shuffle=False, num_workers=4)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, shuffle=False, num_workers=4)

In [None]:
# Hyper-parameters for the run.
# NOTE(review): `batch_size` is not passed to the DataLoaders created in the
# previous cell (they hard-code 128), and `epochs` is not the value the
# training cell passes to train_model -- confirm these are kept in sync.
batch_size = 128
lr = 1e-3
# hidden_size = 100
# std = 1.
epochs = 20


# Build the model and move it to the GPU when one is available, else the CPU.
# The optimizer is created afterwards so it references the moved parameters.
model = MLPModel()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Adam optimizer, a plateau-based learning-rate scheduler and the
# classification loss.
# NOTE(review): no scheduler.step(...) call is visible in this part of the
# file, so the scheduler currently has no effect. mode='max' expects a metric
# that should increase (e.g. accuracy); use 'min' if it is stepped with a loss.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = ReduceLROnPlateau(optimizer=optimizer, mode='max', patience=1, verbose=True, factor=0.2)
criterion = nn.CrossEntropyLoss()

### 3. Training the model

In [None]:
# TensorBoard writer used by train_model to log the loss curves; with no
# arguments it writes under the default `runs/` directory.
writer = SummaryWriter()

In [None]:
# solver = Solver(model, train_loader, val_loader, loss_func=criterion, optimizer=optimizer, learning_rate=lr, writer=writer)

In [None]:
# solver.train(epochs=epochs)

In [None]:
# y_out, _ = model(X_test)
# l1_loss = L1()
# mse_loss = MSE()
# print("L1 loss on test set AFTER training: {:,.0f}".format(l1_loss(rescale(y_out), rescale(y_test))[0].mean() ))
# print("MSE loss on test set AFTER training: {:,.0f}".format(mse_loss(rescale(y_out), rescale(y_test))[0].mean() ))

In [None]:
# y_out = solver.get_dataset_prediction(test_loader)
# l1_loss = L1()
# mse_loss = MSE()
# print("L1 loss on test set BEFORE training: {:,.0f}".format(l1_loss(rescale(y_out), rescale(y_test))[0].mean() ))
# print("MSE loss on test set BEFORE training: {:,.0f}".format(mse_loss(rescale(y_out), rescale(y_test))[0].mean() ))

In [None]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = MLPModel()
# model = model.to(device)

# optim = torch.optim.Adam(model.parameters(), lr=0.01)
# scheduler = ReduceLROnPlateau(optimizer=optim, mode='max', patience=1, verbose=True, factor=0.2)
# criterion = nn.CrossEntropyLoss()

In [None]:
# wandb.init(project="GRF_imitation_learning")
# wandb.watch(model)

In [None]:
# # add test data to test before training
# X_test = [test_dataset[i] for i in range((len(test_dataset)))]
# X_test = np.stack(X_test, axis=0)
# y_test = [test_dataset[i]['target'] for i in range((len(test_dataset)))]
# y_test = np.stack(y_test, axis=0)

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
def train_model(num_epochs):
    """Train the global ``model`` for ``num_epochs`` epochs, validating after each.

    Relies on names created in earlier cells: ``model``, ``device``,
    ``optimizer``, ``criterion``, ``writer``, ``train_loader`` /
    ``val_loader`` and ``train_dataset`` / ``val_dataset``.

    Per epoch, the sample-weighted mean training loss is logged to
    TensorBoard under ``Loss/train`` and the mean validation loss under
    ``Loss/test``, both with the epoch index as the step.

    Args:
        num_epochs: Number of full passes over ``train_loader``.

    Returns:
        Tuple ``(train_loss_values, eval_loss_values, training_accuracy)``;
        each is a list holding one Python float per epoch.
    """
    train_loss_values = []
    eval_loss_values = []
    training_accuracy = []
    for epoch in range(num_epochs):
        start_time = time.time()
        correct = 0
        running_train_loss = 0.0
        running_eval_loss = 0.0

        # Train
        model.train()
        print("Running the training")
        for x, y in tqdm(train_loader):
            x = torch.as_tensor(x, device=device, dtype=torch.float32)
            # CrossEntropyLoss needs integer class indices; convert directly
            # instead of taking a lossy detour through float32.
            y = torch.as_tensor(y, device=device, dtype=torch.long)
            optimizer.zero_grad()
            z = model(x)
            loss = criterion(z, y)
            loss.backward()
            optimizer.step()
            pred = torch.argmax(z, dim=1)
            # Compare on the device and pull out one Python int, so the
            # accumulated accuracy is a plain number rather than a tensor.
            correct += (pred == y).sum().item()
            # The criterion returns the batch mean; weight it by the batch
            # size so a smaller final batch is averaged correctly.
            running_train_loss += loss.item() * x.shape[0]

        epoch_accuracy = correct / len(train_dataset)
        epoch_loss = running_train_loss / len(train_dataset)
        train_loss_values.append(epoch_loss)
        training_accuracy.append(epoch_accuracy)
        # One point per epoch; logging every batch at step `epoch` would
        # stack many values on the same x-coordinate.
        writer.add_scalar("Loss/train", epoch_loss, epoch)

        # Eval
        model.eval()
        print("Running the validation")
        with torch.no_grad():
            for x, y in tqdm(val_loader):
                x = torch.as_tensor(x, device=device, dtype=torch.float32)
                y = torch.as_tensor(y, device=device, dtype=torch.long)
                z = model(x)
                loss = criterion(z, y)
                running_eval_loss += loss.item() * x.shape[0]
        eval_loss = running_eval_loss / len(val_dataset)
        eval_loss_values.append(eval_loss)
        # Log only after the epoch's validation loss has been computed
        # (the tag name is kept so existing TensorBoard runs stay comparable).
        writer.add_scalar('Loss/test', eval_loss, epoch)
        print('Epoch {:03}/{}: | Train Loss: {:.3f} | Eval Loss: {:.3f} | Epoch accuracy: {:.2f} | Training time: {}'.format(
            epoch + 1, num_epochs, epoch_loss, eval_loss, epoch_accuracy, str(datetime.timedelta(seconds=time.time() - start_time))[:7]))
        # wandb.log({'train loss': epoch_loss, 'val loss': eval_loss})
    return train_loss_values, eval_loss_values, training_accuracy


In [None]:
# Run training for 50 epochs and keep the per-epoch curves for plotting.
# NOTE(review): the hyper-parameter cell defines `epochs = 20`, which is not
# used here -- confirm which value is intended.
train_loss_values, eval_loss_values, training_accuracy = train_model(50)
# writer.flush()

In [None]:
# Just a sanity check
len(train_loss_values), len(eval_loss_values), len(training_accuracy)

In [None]:
# Save only the learned weights (state_dict). Create the target directory
# first so torch.save does not fail when it does not exist yet.
PATH = '../saved_models/mlp_model.pth'
os.makedirs(os.path.dirname(PATH), exist_ok=True)
torch.save(model.state_dict(), PATH)

In [None]:
# Save the whole model object (pickled module, not just the weights). Create
# the target directory first so torch.save does not fail when it is missing.
os.makedirs('../saved_models', exist_ok=True)
torch.save(model, '../saved_models/mlp_full_model.pth')

In [None]:
# Plot the per-epoch training and validation loss curves on the current
# pyplot figure and write that figure to imitation.jpg.
plt.plot(train_loss_values, label='train_loss')
plt.plot(eval_loss_values, label='val_loss')
plt.xlabel('Epoch')
plt.ylabel('loss')
plt.legend()
# plt.show()
plt.savefig('imitation.jpg')

In [None]:
# Plot the per-epoch training accuracy returned by train_model.
plt.plot(training_accuracy)
plt.show()

In [None]:
# Close the TensorBoard writer now that training and logging are finished.
writer.close()

In [None]:
!tensorboard --logdir=runs

In [None]:
# replay_files_path = '../data/replay_files'
# replay_files = sorted(os.listdir(replay_files_path))
# replay_files.pop(0)
# # replay_files = replay_files[0:1]
# print(f"total replay files: {len(replay_files)}")
# # replay_files = replay_files[0:1]

# start = time.perf_counter()
# prepare_dict_of_obs_from_replay_files(replay_files, replay_files_path)
# end =  time.perf_counter()

# print(f"Total time needed to process {len(replay_files)}: {end-start}s")
# print(f"Time needed to process a single file: {(end-start)/len(replay_files)}s")

In [None]:
### !!!!!!!!!!!!!!!!!!! Do not delete this cell !!!!!!!!!!!!!!!!!!!
import os
from tqdm import tqdm
import pickle
import time
import numpy as np
from gfootball.env.wrappers import Simple115StateWrapper

# Here we will write the pickle files
def prepare_dict_of_obs_from_replay_files(replay_files, replay_files_path,
                                          obs_save_dir='/home/ssk/Study/GRP/dataset/dict_files',
                                          episode_length=3002):
    """Convert pickled replay files into per-episode dictionaries of frames.

    For every replay file, each step's raw observation is passed through
    ``Simple115StateWrapper.convert_observation``; the step's action values
    are appended to the resulting flat vector, and the vector is stored under
    the key ``"<episode_no>_<step>"``. Steps whose action is empty are
    skipped. One pickle per replay is written to ``obs_save_dir`` under the
    same file name as the source replay.

    Args:
        replay_files: Iterable of replay file names (not full paths).
        replay_files_path: Directory that contains the replay files.
        obs_save_dir: Directory the output pickles are written to; it is
            created (including parents) when missing.
        episode_length: Number of steps read from each episode.

    Returns:
        None. The results are written to disk.
    """
    # makedirs also creates missing parents and tolerates an existing
    # directory, unlike a separate exists-check followed by os.mkdir.
    os.makedirs(obs_save_dir, exist_ok=True)

    for replay in tqdm(replay_files):
        # NOTE: pickle.load can execute arbitrary code while unpickling;
        # only feed this function replay files from a trusted source.
        with open(os.path.join(replay_files_path, replay), 'rb') as pkl_file:
            episode_data = pickle.load(pkl_file)

        episode_no = replay.split('.')[0]
        player = episode_data['players'][0]
        episode = episode_data['observations']
        # Add the first player's 'active' entries as one more per-step
        # observation field, so they are sliced per step like the others.
        episode['active'] = player['active']
        replay_dict = {}

        for step in range(episode_length):
            # Slice every observation field down to this single step.
            raw_obs = {key: item[step] for key, item in episode.items()}

            float115_frame = Simple115StateWrapper.convert_observation([raw_obs], True)[0].tolist()
            action = player['action'][step]

            if len(action) != 0:
                float115_frame.extend(action)
                replay_dict[f'{episode_no}_{step}'] = np.array(float115_frame)

        with open(os.path.join(obs_save_dir, replay), 'wb') as f:
            pickle.dump(replay_dict, f)


### Create dictionaries of obs for each episode file

In [None]:
# Load one of the pickled files from the dict_files directory for inspection.
# NOTE: pickle.load can execute arbitrary code; only open files you trust.
with open('/home/ssk/Study/GRP/dataset/dict_files/3720620.p', 'rb') as handle:
    obs = pickle.load(handle)

In [None]:
# obs

In [None]:
# obs['3720620_1'][0:115]

In [None]:
def save_batched_data():
#     root_save_dir = '/home/ssk/Study/GRP/dataset/batched_data'
#     train_save_dir = join(root_save_dir, 'train')
#     val_save_dir = join(root_save_dir, 'val')
#     test_save_dir = join(root_save_dir, 'test')
    
#     if not os.path.exists(root_save_dir):
#         os.mkdir(root_save_dir)
#     if not os.path.exists(train_save_dir):
#         os.mkdir(train_save_dir)
#     if not os.path.exists(val_save_dir):
#         os.mkdir(val_save_dir)
#     if not os.path.exists(test_save_dir):
#         os.mkdir(test_save_dir)
    # Train
    start = time.perf_counter()
    print(f"Processing Train dataset")
    train_batch_file_prefix = 'train_batch'
    for i, (x, y) in enumerate(tqdm(train_loader)):
        if i==50:
            break
        end = time.perf_counter()
        del(x)
        del(y)
    print(f"Time to load {i} batches containing {128 * i} datapoints: {(end-start):.4} s")
    print(f"Average Time to load a single batch containing {128} datapoints: {(end-start)/(i):.4f} s")
#         print(f"Size of the read batch: {x.element_size() * x.nelement()/1000000} MB")
#         train_batch_name = 
#         train = {}
#         train['obs'] = x
#         train['act'] = y
        

    # Val
#     print(f"Processing Val dataset")
#     for i, (x, y) in enumerate(val_loader):
#         pass
#     # Test
#     print(f"Processing Test dataset")
#     for i, (x, y) in enumerate(test_loader):