### 1. Necessary imports

In [1]:
import sys
sys.path.append("..")

import logging
import os
# import time
# import datetime
import random
import pandas as pd
import numpy as np
import torch
# import torch.nn.functional as F
import torch.nn as nn
# import torch.optim as optim
import csv
# import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
# from torch.optim.lr_scheduler import ReduceLROnPlateau
# from tqdm import tqdm
# from os.path import join, dirname
# from sklearn.model_selection import train_test_split
from imitation_learning.dataset.frame_dataset import Float115Dataset
from models.mlp import MLPModel
from utils.solver import Solver
# from gfootball.env.wrappers import Simple115StateWrapper
from utils.solver import Solver
from torch.utils.tensorboard import SummaryWriter
# from imitation_learning.dataset.dict_dataset import Float115Dataset as DictDataset
# import hydra
# import wandb
# from omegaconf import DictConfig
# import torchtoolbox.transform as transforms
# import torchvision
# import pickle

In [2]:
# Configure root logging once for the notebook: timestamped messages,
# INFO level and above.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
)

In [3]:
# Setting the seeds for result reproducibility: every RNG touched by this
# notebook (Python, NumPy, PyTorch CPU and all CUDA devices) gets the same seed.
SEED = 42
# NOTE: PYTHONHASHSEED is read at interpreter start-up, so this only affects
# Python processes launched after this point, not the running kernel.
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Seed every visible GPU, not only the current device.
torch.cuda.manual_seed_all(SEED)
# Deterministic cuDNN kernels. Benchmark mode must be OFF: when enabled, cuDNN
# autotunes and may pick non-deterministic algorithms, which would defeat the
# deterministic flag above.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [4]:
# Loading the dataset containing names of all the frames
# (frames.csv has no header row; column 0 holds one frame name per row).
logging.info("Reading the dataset")
dataset = pd.read_csv('../data/frames.csv', header=None)[0]
logging.info("Dataset loaded into memory")

# Creating Train, Val and Test splits: shuffle once with a fixed seed, then cut
# at 60% and 80% of the rows -> 60% train / 20% validation / 20% test.
# The shuffled Series is converted to a NumPy string array *before* np.split:
# np.split on a pandas object dispatches to its (deprecated) `swapaxes` method.
# The resulting splits are the same as splitting the Series directly.
train, val, test = np.split(
    dataset.sample(frac=1, random_state=42).to_numpy(dtype='str'),
    [int(.6 * len(dataset)), int(.8 * len(dataset))],
)
assert len(train) + len(val) + len(test) == len(dataset)

# The splits are already string arrays; keep the *_frames names used below.
train_frames, val_frames, test_frames = train, val, test

# Directory holding the per-frame .npy files. Set the FRAMES_NPY_DIR environment
# variable to point elsewhere instead of editing this machine-specific default.
dataset_path = os.environ.get('FRAMES_NPY_DIR', '/home/ssk/Study/GRP/dataset/npy_files')
train_dataset = Float115Dataset(train_frames, dataset_path)
val_dataset = Float115Dataset(val_frames, dataset_path)
test_dataset = Float115Dataset(test_frames, dataset_path)

logging.info(f"Number of training samples: {len(train_dataset)}")
logging.info(f"Number of validation samples: {len(val_dataset)}")
logging.info(f"Number of test samples: {len(test_dataset)}")

2022-02-09 21:46:40,995 - Reading the dataset
2022-02-09 21:46:47,086 - Dataset loaded inot the memory
2022-02-09 21:46:56,333 - Number of training samples: 8642880
2022-02-09 21:46:56,333 - Number of validation samples: 2880960
2022-02-09 21:46:56,334 - Number of test samples: 2880960


In [5]:
# Drop the notebook's references to the raw frame list and the intermediate
# splits now that the Float115Dataset objects have been built. Memory is only
# reclaimed if nothing else (e.g. the datasets) still references these objects.
del train, val, test, dataset, train_frames, val_frames, test_frames

In [6]:
# Creating the dataloaders.
BATCH_SIZE = 128
NUM_WORKERS = 4

# Only the training loader is shuffled. Validation and test are pure evaluation
# passes, where ordering carries no benefit; a fixed order keeps them
# deterministic and avoids drawing from the global torch RNG for their sampling.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)


In [7]:

# Learning rate; it is given both to the optimizer and (below) to the Solver.
lr = 1e-3

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the MLP and move its parameters onto the selected device.
model = MLPModel().to(device)

# Adam over all model parameters at the learning rate above.
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

# Classification objective. nn.CrossEntropyLoss applies log-softmax itself.
# NOTE(review): the recorded run plateaus at ~2.63 train/val loss from epoch 2
# onward. Confirm MLPModel returns raw logits (no final softmax) and that the
# input features are scaled, since the model may only be learning class priors.
criterion = nn.CrossEntropyLoss()

In [8]:
# Number of training epochs. Each epoch took ~16 min of training plus ~5.5 min
# of validation in the recorded run, so 100 epochs is a multi-day job.
EPOCHS = 100

writer = SummaryWriter()
try:
    solver = Solver(model, train_loader, val_loader, criterion, lr, optimizer, writer=writer)
    solver.train(epochs=EPOCHS)
finally:
    # Always flush and close the TensorBoard event file, even if training is
    # interrupted (the recorded run below ended with KeyboardInterrupt).
    # Otherwise the logs of the completed epochs can be left unflushed.
    writer.close()

Running the training loop


100%|██████████| 67523/67523 [16:07<00:00, 69.78it/s]


Validating the model


100%|██████████| 22508/22508 [05:45<00:00, 65.13it/s]


(Epoch 1 / 100) train loss: 2.926782; val loss: 2.926834


100%|██████████| 67523/67523 [16:31<00:00, 68.10it/s] 


Validating the model


100%|██████████| 22508/22508 [05:28<00:00, 68.52it/s]


(Epoch 2 / 100) train loss: 2.631352; val loss: 2.633014


100%|██████████| 67523/67523 [16:31<00:00, 68.12it/s]


Validating the model


100%|██████████| 22508/22508 [05:27<00:00, 68.83it/s]


(Epoch 3 / 100) train loss: 2.628381; val loss: 2.631808


100%|██████████| 67523/67523 [16:26<00:00, 68.43it/s]


Validating the model


100%|██████████| 22508/22508 [05:24<00:00, 69.26it/s]


(Epoch 4 / 100) train loss: 2.629849; val loss: 2.631140


100%|██████████| 67523/67523 [16:19<00:00, 68.92it/s]


Validating the model


100%|██████████| 22508/22508 [05:24<00:00, 69.36it/s]


(Epoch 5 / 100) train loss: 2.629908; val loss: 2.629723


100%|██████████| 67523/67523 [16:17<00:00, 69.09it/s]


Validating the model


100%|██████████| 22508/22508 [05:23<00:00, 69.54it/s]


(Epoch 6 / 100) train loss: 2.630725; val loss: 2.632274


100%|██████████| 67523/67523 [16:24<00:00, 68.60it/s]  


Validating the model


100%|██████████| 22508/22508 [05:23<00:00, 69.61it/s]


(Epoch 7 / 100) train loss: 2.633233; val loss: 2.640430


100%|██████████| 67523/67523 [16:16<00:00, 69.17it/s]


Validating the model


100%|██████████| 22508/22508 [05:22<00:00, 69.83it/s]


(Epoch 8 / 100) train loss: 2.631522; val loss: 2.630884


100%|██████████| 67523/67523 [16:18<00:00, 69.04it/s]


Validating the model


100%|██████████| 22508/22508 [05:24<00:00, 69.41it/s]


(Epoch 9 / 100) train loss: 2.631349; val loss: 2.627547


100%|██████████| 67523/67523 [16:19<00:00, 68.93it/s]


Validating the model


100%|██████████| 22508/22508 [05:23<00:00, 69.61it/s]


(Epoch 10 / 100) train loss: 2.629255; val loss: 2.639483


100%|██████████| 67523/67523 [16:20<00:00, 68.85it/s]


Validating the model


100%|██████████| 22508/22508 [05:22<00:00, 69.70it/s]


(Epoch 11 / 100) train loss: 2.633057; val loss: 2.618521


100%|██████████| 67523/67523 [16:18<00:00, 69.03it/s]


Validating the model


100%|██████████| 22508/22508 [05:22<00:00, 69.75it/s]


(Epoch 12 / 100) train loss: 2.631567; val loss: 2.631134


100%|██████████| 67523/67523 [16:18<00:00, 68.99it/s]


Validating the model


100%|██████████| 22508/22508 [05:22<00:00, 69.85it/s]


(Epoch 13 / 100) train loss: 2.629468; val loss: 2.631172


100%|██████████| 67523/67523 [16:16<00:00, 69.17it/s]


Validating the model


100%|██████████| 22508/22508 [05:22<00:00, 69.85it/s]


(Epoch 14 / 100) train loss: 2.629866; val loss: 2.629902


100%|██████████| 67523/67523 [16:16<00:00, 69.13it/s]


Validating the model


100%|██████████| 22508/22508 [05:22<00:00, 69.69it/s]


(Epoch 15 / 100) train loss: 2.629492; val loss: 2.634307


100%|██████████| 67523/67523 [16:17<00:00, 69.08it/s]


Validating the model


100%|██████████| 22508/22508 [05:23<00:00, 69.63it/s]


(Epoch 16 / 100) train loss: 2.629440; val loss: 2.622425


100%|██████████| 67523/67523 [16:18<00:00, 69.04it/s]


Validating the model


100%|██████████| 22508/22508 [05:22<00:00, 69.84it/s]


(Epoch 17 / 100) train loss: 2.632490; val loss: 2.621897


100%|██████████| 67523/67523 [16:18<00:00, 68.99it/s]


Validating the model


100%|██████████| 22508/22508 [05:21<00:00, 70.02it/s]


(Epoch 18 / 100) train loss: 2.629007; val loss: 2.622374


100%|██████████| 67523/67523 [16:19<00:00, 68.91it/s]


Validating the model


100%|██████████| 22508/22508 [05:21<00:00, 69.90it/s]


(Epoch 19 / 100) train loss: 2.628977; val loss: 2.641817


100%|██████████| 67523/67523 [16:18<00:00, 69.02it/s]


Validating the model


100%|██████████| 22508/22508 [05:22<00:00, 69.79it/s]


(Epoch 20 / 100) train loss: 2.632374; val loss: 2.633086


100%|██████████| 67523/67523 [16:17<00:00, 69.06it/s]


Validating the model


100%|██████████| 22508/22508 [05:21<00:00, 70.07it/s]


(Epoch 21 / 100) train loss: 2.630613; val loss: 2.638321


100%|██████████| 67523/67523 [16:17<00:00, 69.08it/s]


Validating the model


100%|██████████| 22508/22508 [05:22<00:00, 69.74it/s]


(Epoch 22 / 100) train loss: 2.630618; val loss: 2.627060


100%|██████████| 67523/67523 [16:17<00:00, 69.09it/s]


Validating the model


100%|██████████| 22508/22508 [05:45<00:00, 65.14it/s]


(Epoch 23 / 100) train loss: 2.631845; val loss: 2.630183


100%|██████████| 67523/67523 [16:18<00:00, 68.98it/s]


Validating the model


100%|██████████| 22508/22508 [05:23<00:00, 69.64it/s]


(Epoch 24 / 100) train loss: 2.631172; val loss: 2.626233


100%|██████████| 67523/67523 [16:19<00:00, 68.91it/s]


Validating the model


100%|██████████| 22508/22508 [05:22<00:00, 69.73it/s]


(Epoch 25 / 100) train loss: 2.630315; val loss: 2.633946


100%|██████████| 67523/67523 [16:12<00:00, 69.40it/s]


Validating the model


100%|██████████| 22508/22508 [05:20<00:00, 70.12it/s]


(Epoch 26 / 100) train loss: 2.630819; val loss: 2.640781


100%|██████████| 67523/67523 [16:16<00:00, 69.17it/s]


Validating the model


100%|██████████| 22508/22508 [05:21<00:00, 69.97it/s]


(Epoch 27 / 100) train loss: 2.631715; val loss: 2.629423


100%|██████████| 67523/67523 [16:15<00:00, 69.21it/s]


Validating the model


100%|██████████| 22508/22508 [05:23<00:00, 69.53it/s]


(Epoch 28 / 100) train loss: 2.632710; val loss: 2.633877


100%|██████████| 67523/67523 [16:17<00:00, 69.09it/s]


Validating the model


100%|██████████| 22508/22508 [05:21<00:00, 69.93it/s]


(Epoch 29 / 100) train loss: 2.631031; val loss: 2.621329


100%|██████████| 67523/67523 [16:15<00:00, 69.22it/s]


Validating the model


100%|██████████| 22508/22508 [05:20<00:00, 70.22it/s]


(Epoch 30 / 100) train loss: 2.629456; val loss: 2.638918


100%|██████████| 67523/67523 [16:14<00:00, 69.31it/s]


Validating the model


100%|██████████| 22508/22508 [05:21<00:00, 70.02it/s]


(Epoch 31 / 100) train loss: 2.626720; val loss: 2.624118


100%|██████████| 67523/67523 [16:12<00:00, 69.43it/s]


Validating the model


100%|██████████| 22508/22508 [05:22<00:00, 69.72it/s]


(Epoch 32 / 100) train loss: 2.629059; val loss: 2.634896


100%|██████████| 67523/67523 [16:21<00:00, 68.78it/s]


Validating the model


100%|██████████| 22508/22508 [05:19<00:00, 70.55it/s]


(Epoch 33 / 100) train loss: 2.630606; val loss: 2.627949


100%|██████████| 67523/67523 [16:13<00:00, 69.38it/s]


Validating the model


100%|██████████| 22508/22508 [05:24<00:00, 69.36it/s]


(Epoch 34 / 100) train loss: 2.630935; val loss: 2.623745


100%|██████████| 67523/67523 [16:27<00:00, 68.36it/s]


Validating the model


100%|██████████| 22508/22508 [05:24<00:00, 69.32it/s]


(Epoch 35 / 100) train loss: 2.631285; val loss: 2.624766


100%|██████████| 67523/67523 [17:27<00:00, 64.45it/s]


Validating the model


100%|██████████| 22508/22508 [05:29<00:00, 68.22it/s]


(Epoch 36 / 100) train loss: 2.633591; val loss: 2.626733


 77%|███████▋  | 51808/67523 [12:35<03:49, 68.53it/s]


KeyboardInterrupt: 

In [None]:
# solver.best_params

OSError: "/bin/zsh" shell not found