In [None]:
import os
from os import path
from collections import deque
import random
from random import shuffle

import matplotlib.pyplot as plt

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from CustomDatasets import GameMemoryDataset,PlayerMemoryDataset

from CardGames import Hearts

# Absolute path of the working directory; checkpoints are read from and
# written to the Models/ tree beneath it.
CURRENT_DIR = path.abspath(path.curdir)

# Names assigned, by seat index, to the players of the training game.
NAME_WHEEL = ['Lain', 'Turing', 'Tesla', 'Silver']

# Number of records sampled from a replay memory per training step, and the
# DataLoader minibatch size used to iterate over that sample.
BATCH_SIZE = 1024
MINIBATCH_SIZE = 64

# Multiplicative per-game decay applied to game.EPSILON.
DECAY = 0.999

# Progress counters driving the training loop.
game_counter = batch_counter = 0

# Per-evaluation point totals: heuristic-controlled seats vs. AI seats.
h_list, ai_list = [], []


def _latest_model_version(directory, prefix):
    """Return the highest N among files named ``'<prefix> N'`` in *directory*.

    The directory (and any missing parents) is created first. A fresh checkout
    without a Models/ tree therefore starts from scratch instead of raising
    FileNotFoundError. Files that do not follow the ``'<prefix> N'`` naming
    scheme (e.g. ``.DS_Store``) are ignored rather than crashing the
    ``int()`` parse.

    Returns:
        The largest checkpoint number as an int, or None if no matching file exists.
    """
    os.makedirs(directory, exist_ok=True)
    versions = []
    for filename in os.listdir(directory):
        stem, _, number = filename.partition(' ')
        if stem == prefix and number.isdecimal():
            versions.append(int(number))
    return max(versions, default=None)


game = Hearts(train=True)

# Strategy network: resume from the newest GNet checkpoint, or write version 1.
latest = _latest_model_version(CURRENT_DIR + '/Models/Gamma/', 'GNet')
if latest is None:
    game.version = 1
    torch.save(game.strat_nnet.state_dict(), CURRENT_DIR + f'/Models/Gamma/GNet {game.version}')
else:
    game.version = latest
    game.strat_nnet.load_state_dict(torch.load(CURRENT_DIR + f'/Models/Gamma/GNet {game.version}'))

game.optimizer = torch.optim.Adam(game.strat_nnet.parameters(), lr=0.0001)

for i, player in enumerate(game.playerlist):

    player.name = NAME_WHEEL[i]
    player.optimizer = torch.optim.Adam(player.response_nnet.parameters(), lr=0.0001)

    player_dir = CURRENT_DIR + f'/Models/{player.name}/'

    # Response network: resume from the newest QNet checkpoint, or write version 1.
    latest = _latest_model_version(player_dir + 'Versions/', 'QNet')
    if latest is None:
        player.version = 1
        torch.save(player.response_nnet.state_dict(), player_dir + f'Versions/QNet {player.version}')
    else:
        player.version = latest
        player.response_nnet.load_state_dict(torch.load(player_dir + f'Versions/QNet {player.version}'))

    # "QNet Prime" holds the weights loaded into best_response_nnet.
    prime_path = player_dir + 'QNet Prime'
    if path.isfile(prime_path):
        player.best_response_nnet.load_state_dict(torch.load(prime_path))
    else:
        # First run: seed the Prime file from the response net. Also sync the
        # in-memory best-response net so it matches the file just written.
        torch.save(player.response_nnet.state_dict(), prime_path)
        player.best_response_nnet.load_state_dict(player.response_nnet.state_dict())



# Main loop: play self-play games. Every 1000th game (once game.memory holds
# more than game.MIN_MEM_LEN records) run one training step for the strategy
# net and every player's response net, and checkpoint them. Every 5th training
# step, play 10 evaluation games and record point totals.
while batch_counter < 3999:

    # Play until the game reports the round is over.
    while not game.round_over:
        game.play_round()

    game_counter += 1
    game.reset()
    shuffle(game.playerlist)
    # Decay exploration rate, floored at 0.05.
    game.EPSILON = max(game.EPSILON * DECAY, 0.05)

    if len(game.memory) > game.MIN_MEM_LEN and game_counter % 1000 == 0:

        batch_counter += 1

        # --- Strategy network (game.strat_nnet) ---
        # min() guards against MIN_MEM_LEN < BATCH_SIZE, where random.sample
        # would raise ValueError.
        batch = random.sample(list(game.memory), min(BATCH_SIZE, len(game.memory)))
        dataset = GameMemoryDataset(batch)
        dataloader = DataLoader(dataset, batch_size=MINIBATCH_SIZE, shuffle=True)

        for game_state, memory, mask, action in dataloader:
            prediction = game.strat_nnet(game_state, memory)
            game.optimizer.zero_grad()
            # Mean squared error between the masked prediction and the recorded action.
            output = torch.mean((prediction * mask - action) ** 2)
            output.backward()
            game.optimizer.step()

        game.version += 1
        torch.save(game.strat_nnet.state_dict(), CURRENT_DIR + f'/Models/Gamma/GNet {game.version}')

        # --- Response networks, one per player ---
        for player in game.playerlist:

            # A player's memory is not covered by the game.memory size check
            # above, so clamp the sample size. list() gives O(1) indexing
            # (matching the game.memory sampling above).
            sample_size = min(BATCH_SIZE, len(player.memory))
            if sample_size > 0:
                batch = random.sample(list(player.memory), sample_size)
                dataset = PlayerMemoryDataset(batch)
                dataloader = DataLoader(dataset, batch_size=MINIBATCH_SIZE, shuffle=True)

                for game_state, memory, action, reward, maxq in dataloader:
                    prediction = player.response_nnet(game_state, memory)
                    player.optimizer.zero_grad()
                    # Q-value of the taken action, one per sample. Sum over the
                    # action dimension only: summing over the whole minibatch
                    # would compare every target against a single scalar.
                    # Assumes `action` is a one-hot tensor shaped like
                    # `prediction`, and reward/maxq hold one value per sample
                    # — TODO confirm against PlayerMemoryDataset.
                    q_taken = torch.sum(prediction * action, dim=-1).reshape(-1)
                    target = (reward + maxq).reshape(-1).to(q_taken.dtype)
                    loss = torch.mean((target - q_taken) ** 2)
                    loss.backward()
                    player.optimizer.step()

            # Versions are bumped even when training was skipped, so every
            # player keeps a checkpoint for each training step.
            player.version += 1
            torch.save(player.response_nnet.state_dict(), CURRENT_DIR + f'/Models/{player.name}/Versions/QNet {player.version}')
            torch.save(player.response_nnet.state_dict(), CURRENT_DIR + f'/Models/{player.name}/QNet Prime')
            # Copy the weights directly instead of re-reading the file just written.
            player.best_response_nnet.load_state_dict(player.response_nnet.state_dict())

        print(f'batch {batch_counter} finished')

        if batch_counter % 5 == 0:

            eval_game = Hearts()
            eval_game.strat_nnet.load_state_dict(torch.load(CURRENT_DIR + f'/Models/Gamma/GNet {game.version}'))

            # Seats 0 and 1 are AI-driven. Any seat whose model is "Heuristic"
            # is tallied as the heuristic side below; every other seat counts as AI.
            eval_game.playerlist[0].model = "AI"
            eval_game.playerlist[1].model = "AI"

            # Latest checkpoint number per training-player name. Look up the
            # chosen name's own version rather than playerlist[0]'s, since
            # players' version counters can differ (e.g. different numbers of
            # pre-existing checkpoints when the script started).
            latest_versions = {p.name: p.version for p in game.playerlist}

            for _ in range(10):

                for player in eval_game.playerlist:
                    if player.model == "AI":
                        player.name = np.random.choice(NAME_WHEEL)
                        player.response_nnet.load_state_dict(torch.load(CURRENT_DIR + f'/Models/{player.name}/Versions/QNet {latest_versions[player.name]}'))

                shuffle(eval_game.playerlist)

                while not eval_game.round_over:
                    eval_game.play_round()

                eval_game.reset()

            # NOTE(review): points are read after the final reset(). Confirm
            # that Hearts.reset() does not clear player.points; otherwise these
            # totals would always be 0.
            h_points = sum(p.points for p in eval_game.playerlist if p.model == "Heuristic")
            ai_points = sum(p.points for p in eval_game.playerlist if p.model != "Heuristic")

            h_list.append(h_points)
            ai_list.append(ai_points)

# Red: heuristic seats' total points per evaluation; green: AI seats' total.
plt.plot(h_list, "red")
plt.plot(ai_list, "green")

batch 1 finished
batch 2 finished
batch 3 finished
batch 4 finished
batch 5 finished
batch 6 finished
batch 7 finished
batch 8 finished
batch 9 finished
batch 10 finished
batch 11 finished
batch 12 finished
batch 13 finished
batch 14 finished
batch 15 finished
batch 16 finished
batch 17 finished
batch 18 finished
batch 19 finished
batch 20 finished
batch 21 finished
batch 22 finished
batch 23 finished
batch 24 finished
batch 25 finished
batch 26 finished
batch 27 finished
batch 28 finished
batch 29 finished
batch 30 finished
batch 31 finished
batch 32 finished
batch 33 finished
batch 34 finished
batch 35 finished
batch 36 finished
batch 37 finished
batch 38 finished
batch 39 finished
batch 40 finished
batch 41 finished
batch 42 finished
batch 43 finished
batch 44 finished
batch 45 finished
batch 46 finished
batch 47 finished
batch 48 finished
batch 49 finished
batch 50 finished
batch 51 finished
batch 52 finished
batch 53 finished
batch 54 finished
batch 55 finished
batch 56 finished
b

batch 438 finished
batch 439 finished
batch 440 finished
batch 441 finished
batch 442 finished
batch 443 finished
batch 444 finished
batch 445 finished
batch 446 finished
batch 447 finished
batch 448 finished
batch 449 finished
batch 450 finished
batch 451 finished
batch 452 finished
batch 453 finished
batch 454 finished
batch 455 finished
batch 456 finished
batch 457 finished
batch 458 finished
batch 459 finished
batch 460 finished
batch 461 finished
batch 462 finished
batch 463 finished
batch 464 finished
batch 465 finished
batch 466 finished
batch 467 finished
batch 468 finished
batch 469 finished
batch 470 finished
batch 471 finished
batch 472 finished
batch 473 finished
batch 474 finished
batch 475 finished
batch 476 finished
batch 477 finished
batch 478 finished
batch 479 finished
batch 480 finished
batch 481 finished
batch 482 finished
batch 483 finished
batch 484 finished
batch 485 finished
batch 486 finished
batch 487 finished
batch 488 finished
batch 489 finished
batch 490 fi