In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import DataLoader
from torchvision import datasets

%load_ext tensorboard
import tensorflow as tf
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter

import os
import json
import numpy as np
import random
import time

2023-05-30 20:44:58.373769: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
gpu = torch.cuda.is_available()
device = torch.device("cuda:0" if gpu else "cpu")
print("device:", device)

device: cuda:0


In [3]:
ICards = {'Bash':True,
         'Defend':False,
         'Strike':True,
         'Anger':True,
         'Armaments':False,
         'Body Slam':True,
         'Clash':True,
         'Cleave':False,
         'Clothesline':True,
         'Flex':False,
         'Havoc':False,
         'Headbutt':True,
         'Heavy Blade':True,
         'Iron Wave':True,
         'Perfected Strike':True,
         'Pommel Strike':True,
         'Shrug It Off':False,
         'Sword Boomerang':False,
         'Thunderclap':False,
         'True Grit':False,
         'Twin Strike':True,
         'Warcry':False,
         'Wild Strike':True,
         'Battle Trance':False,
         'Blood for Blood':True,
         'Bloodletting':False,
         'Burning Pact':False,
         'Carnage':True,
         'Combust':False,
         'Dark Embrace':False,
         'Disarm':True,
         'Dropkick':True,
         'Dual Wield':False,
         'Entrench':False,
         'Evolve':False,
         'Feel No Pain':False,
         'Fire Breathing':False,
         'Flame Barrier':False,
         'Ghostly Armor':False,
         'Hemokinesis':True,
         'Infernal Blade':False,
         'Inflame':False,
         'Intimidate':False,
         'Metallicize':False,
         'Power Through':False,
         'Pummel':True,
         'Rage':False,
         'Rampage':True,
         'Reckless Charge':True,
         'Rupture':False,
         'Searing Blow':True,
         'Second Wind':False,
         'Seeing Red':False,
         'Sentinel':False,
         'Sever Soul':True,
         'Shockwave':False,
         'Spot Weakness':True,
         'Uppercut':True,
         'Whirlwind':False,
         'Barricade':False,
         'Berserk':False,
         'Bludgeon':True,
         'Brutality':False,
         'Corruption':False,
         'Demon Form':False,
         'Double Tap':False,
         'Exhume':False,
         'Feed':True,
         'Fiend Fire':False,
         'Immolate':False,
         'Impervious':False,
         'Juggernaut':False,
         'Limit Break':False,
         'Offering':False,
         'Reaper':False
         }

GCards = {'Bandage Up':False,
         'Blind':True,
         'Dark Shackles':True,
         'Deep Breath':False,
         'Discovery':False,
         'Dramatic Entrance':False,
         'Enlightenment':False,
         'Finesse':False,
         'Flash of Steel':True,
         'Forethought':False,
         'Good Instincts':False,
         'Impatience':False,
         'Jack of All Trades':False,
         'Madness':False,
         'Mind Blast':True,
         'Panacea':False,
         'Panic Button':False,
         'Purity':False,
         'Swift Strike':True,
         'Trip':True,
         'Apotheosis':False,
         'Chrysalis':False,
         'Hand of Greed':True,
         'Magnetism':False,
         'Master of Strategy':False,
         'Mayhem':False,
         'Metamorphosis':False,
         'Panache':False,
         'Sadistic Nature':False,
         'Secret Technique':False,
         'Secret Weapon':False,
         'The Bomb':False,
         'Thinking Ahead':False,
         'Transmutation':False,
         'Violence':False,
         'Apparition':False,
         'Bite':True,
         'Expunger':True,
         'Insight':False,
         'J.A.X.':False,
         'Ritual Dagger':True,
         'Safety':False,
         'Smite':True,
         'Through Violence':True,
         'Slimed':False}

BCards = ['Burn',
          'Dazed',
          'Wound',
          'Void',
          "Ascender's Bane",
          'Clumsy',
          'Curse of the Bell',
          'Decay',
          'Doubt',
          'Injury',
          'Necronomicurse',
          'Normality',
          'Pain',
          'Parasite',
          'Pride',
          'Regret',
          'Shame',
          'Writhe']

playable = ICards | GCards
cardNames = list(playable.keys()) + BCards
cardIDs = dict(zip(cardNames, range(len(cardNames))))

neutrals = ['Dazed',
            'Wound',
            'Void',
            "Ascender's Bane",
            'Clumsy',
            'Curse of the Bell',
            'Injury',
            'Necronomicurse',
            'Writhe']

CARD_PRIORITY_LIST = [
    "Apotheosis",
    "Ghostly Armor",
    "Perfected Strike",
    "Whirlwind",
    "Battle Trance",
    "Demon Form",
    "Rage",
    "Offering",
    "Impervious",
    "Immolate",
    "Limit Break",
    "Flame Barrier",
    "Master of Strategy",
    "Inflame",
    "Disarm",
    "Shrug it Off",
    "Double Tap",
    "Thunderclap",
    "Metallicize",
    "Pommel Strike",
    "Shockwave",
    "Uppercut",
    "J.A.X.",
    "Panic Button",
    "Flash of Steel",
    "Flex",
    "Anger",
    #"Skip",
    "Secret Weapon",
    "Finesse",
    "Mayhem",
    "Panache",
    "Secret Technique",
    "Metamorphosis",
    "Thinking Ahead",
    "Madness",
    "Discovery",
    "Chrysalis",
    "Deep Breath",
    "Trip",
    "Enlightenment",
    "Heavy Blade",
    "Feed",
    "Fiend Fire",
    "Twin Strike",
    "Headbutt",
    "Seeing Red",
    "Combust",
    "Clash",
    "Dark Shackles",
    "Sword Boomerang",
    "Dramatic Entrance",
    "Bludgeon",
    "Hand of Greed",
    "Evolve",
    "Violence",
    "Bite",
    "Carnage",
    "Clothesline",
    "Bash",
    "Bandage Up",
    "Panacea",
    "Reckless Charge",
    "Infernal Blade",
    "Spot Weakness",
    "Strike",
    #"Shiv",
    "Havoc",
    "Ritual Dagger",
    "Dropkick",
    "Feel No Pain",
    "Swift Strike",
    "Corruption",
    "Magnetism",
    "Bloodletting",
    "Iron Wave",
    "Armaments",
    "Mind Blast",
    "Ascender's Bane",
    "Dazed",
    "Void",
    "Rampage",
    "Ghostly Armor",
    "True Grit",
    "Blind",
    "Good Instincts",
    "Pummel",
    "Hemokinesis",
    "Exhume",
    "Reaper",
    "Cleave",
    "Warcry",
    "Purity",
    "Dual Wield",
    "Wild Strike",
    "Defend",
    "Body Slam",
    "Sever Soul",
    "Burning Pact",
    "Brutality",
    "Barricade",
    "Intimidate",
    "Juggernaut",
    "Sadistic Nature",
    "Dark Embrace",
    "Power Through",
    "Transmutation",
    "Sentinel",
    "Rupture",
    "Slimed",
    "Fire Breathing",
    "Second Wind",
    "Impatience",
    "The Bomb",
    "Jack of All Trades",
    "Searing Blow",
    "Blood for Blood",
    "Berserk",
    "Entrench",
    "Forethought",
    "Clumsy",
    "Parasite",
    "Shame",
    "Injury",
    "Wound",
    "Writhe",
    "Doubt",
    "Burn",
    "Decay",
    "Regret",
    "Necronomicurse",
    "Pain",
    "Normality",
    "Pride"
]

In [13]:
def card2tensor(card, played=False):
    upgraded = card.find('+') != -1
    card = card.replace('+', '')
    card = cardIDs[card]
    if played:
        card = F.one_hot(torch.tensor(card), len(playable)+1)
    else:
        card = F.one_hot(torch.tensor(card), len(cardNames)+1)
    if upgraded:
        card[-1] = 1
    return card

class DatReader():
    def __init__(self, start, end, shuffle, aug):
        self.start = start
        self.end = end
        self.shuffle = shuffle
        self.aug = aug
        
    def __iter__(self):
        dir_path = 'data/'
        self.dat_files = []
        for path in os.listdir(dir_path):
            if os.path.isfile(os.path.join(dir_path, path)):
                self.dat_files.append(path)
                
        random.Random(0).shuffle(self.dat_files)
                
        self.file = self.start
        self.dat = open('data/'+self.dat_files[self.file], 'r')
        
        self.prev = {'hp':75, 'maxhp':75, 'energy':3, 'floor':1}
            
        return self
    
    def __next__(self):
        x = self.dat.readline()
        while x == '' or x.strip() == 'done':
            self.file += 1
            if self.file == len(self.dat_files) or self.file == self.end:
                raise StopIteration
                
            self.dat = open('data/'+self.dat_files[self.file], 'r')
            x = self.dat.readline()
            self.prev = {'hp':75, 'maxhp':75, 'energy':3, 'floor':1}
        
        x = json.loads(x)
        
        label = card2tensor(x['card'], True)
        
        hand = x['hand'][:9]
        insertion = random.randint(0, len(hand))
        hand.insert(insertion, x['card'])
        if self.aug:
            hand = [c for c in hand if c not in neutrals]
            addnum = random.randint(0, 10-len(hand))
            adding = np.random.randint(len(neutrals), size=(addnum))
            adding = [neutrals[i] for i in adding]
            hand += adding
        if self.shuffle:
            random.shuffle(hand)
        state = [card2tensor(c) for c in hand]
        state = torch.cat(state, 0)
        padding = torch.zeros(((10-len(hand))*(len(cardNames)+1)))
        
        hp = x['hp'] if x['hp'] != -1 else self.prev['hp']
        maxhp = x['maxhp'] if x['maxhp'] != -1 else self.prev['maxhp']
        energy = x['energy'] if x['energy'] != -1 else (3 if self.prev['energy'] == 0 else self.prev['energy']-1)
        floor = x['floor'] if x['floor'] != -1 else self.prev['floor']
        v2p0 = 1 if x['v2.0'] else 0
        
        self.prev = x
        
        misc = torch.tensor([hp/100, maxhp/100, x['block']/100, energy/10, floor/100, v2p0])
        state = torch.cat((state, padding, misc), 0)
        
        return (state.float(), label.float())
    
class SpireData(torch.utils.data.IterableDataset):
    def __init__(self, start=0, end=np.inf, shuffle=False, aug=False):
        super(SpireData).__init__()
        self.start = start
        self.end = end
        self.shuffle = shuffle
        self.aug = aug

    def __iter__(self):
        return iter(DatReader(self.start, self.end, self.shuffle, self.aug))

def pred_acc(inputs, preds, labels):
    hands = tf.identity(inputs.cpu()).numpy()
    preds = preds.cpu().detach().numpy()
    labels = labels.cpu().detach().numpy()
    
    correct = 0
    t2 = 0
    samples = 0
    
    for i in range(len(labels)):
        label = labels[i]
        card1 = np.argmax(label[:-1])
        upgraded1 = label[-1] > 0.5

        state = hands[i]
        hand = []
        for j in range(10):
            c = state[j*(len(cardNames)+1):(j+1)*(len(cardNames)+1)]
            c2 = np.argmax(c[:-1])
            #not empty and playable
            if sum(c) > 0 and c2 < len(playable):
                if c[-1] > 0.5:
                    c2 += 118
                hand.append(c2)
        canplay = [c % 118 for c in hand]

        pred = preds[i]
        cards = pred[:-1]
        only_playable = lambda x,i : x if i in canplay else 0
        cards = np.array([only_playable(c,i) for i,c in enumerate(cards)])
        card2 = np.argmax(cards)
        
        cards[card2] = 0
        card3 = np.argmax(cards)

        upgraded2 = pred[-1] > 0.5
        upgradeChoice = card2 in hand and card2+118 in hand
        upgradeComp = upgradeChoice and card1 == card2 and upgraded1 == upgraded2
        simpleComp = not upgradeChoice and card1 == card2
        if simpleComp or upgradeComp:
            correct += 1
        approxComp = card1 == card2 or card1 == card3
        t2 += approxComp
        samples += 1
        
    return (correct, t2, samples)

def eval_acc(net, criterion):
    random.seed(0)
    ds = SpireData(start=0, end=100)
    loader = iter(DataLoader(ds, batch_size=100))
    
    correct = 0
    t2 = 0
    samples = 0
    total_loss = 0
    batches = 0
    
    for inputs,labels in loader:
        if gpu:
            inputs = inputs.to(device)
            labels = labels.to(device)

        preds = net(inputs)
        loss = criterion(preds, labels)
        total_loss += loss.item()
        batches += 1
        
        c,t,s = pred_acc(inputs, preds, labels)
        
        correct += c
        t2 += t
        samples += s
    
    #print(samples)
    return (correct/samples, t2/samples, total_loss/batches)
    
# t0 = time.perf_counter()
criterion = nn.BCELoss()
print(eval_acc(net, criterion))
# print(time.perf_counter() - t0)

# 51008
# (0.5588927227101631, 0.7916209222082811, 0.02493530709784908)

# best2:
# (0.5629901191969887, 0.7974239335006273, 0.4845576142881306)

(0.5629901191969887, 0.7974239335006273, 0.4845576142881306)


In [10]:
def rand_acc():
    ds = SpireData(start=0, end=100)
    
    sum_correct = 0.
    sum_close = 0.
    sum_t2 = 0.
    samples = 0
    
    for state, label in ds:
        state = state.numpy()
        hand = []
        for i in range(10):
            c = state[i*(len(cardNames)+1):(i+1)*(len(cardNames)+1)]
            c2 = np.argmax(c[:-1])
            #not empty and playable
            if sum(c) > 0 and c2 < len(playable):
                if c[-1] > 0.5:
                    c2 += len(playable)
                hand.append(c2)
            
        label = label.numpy()
        card = np.argmax(label[:-1])
        close = [(c % len(playable)) == card for c in hand]
        r_close = sum(close)/len(hand)
        sum_close += r_close
        
        sum_t2 += r_close
        
        if len(hand) > 1:
            r_t2 = sum(close)/(len(hand)-1)
            sum_t2 += (1 - r_close) * r_t2
        
        if label[-1] > 0.5:
            card += len(playable)
        correct = [c == card for c in hand]
        sum_correct += sum(correct)/len(hand)
        
        samples += 1
        
    print(sum_correct/samples)
    print(sum_close/samples)
    print(sum_t2/samples)
    
# 0.3445052074103245
# 0.34928241702886265

# 0.3414854226199696
# 0.3462408035626854
# 0.5859287931846844

# 0.3343652131754624
# 0.34034448066664974
# 0.5832226130729942

def rand_stoch():
    random.seed(0)
    ds = SpireData(start=0, end=100)
    
    correct = 0.
    t2 = 0.
    samples = 0
    
    for state, label in ds:
        state = state.numpy()
        hand = []
        for i in range(10):
            c = state[i*137:(i+1)*137]
            c2 = np.argmax(c[:-1])
            #not empty and playable
            if sum(c) > 0 and c2 < len(playable):
                if c[-1] > 0.5:
                    c2 += 118
                hand.append(c2)
            
        label = label.numpy()
        card1 = np.argmax(label[:-1])
        upgraded1 = label[-1] > 0.5
        
        card2 = hand[random.randint(0, len(hand)-1)]
        upgraded2 = card2 >= len(playable)
        
        hand2 = hand.copy()
        hand2.remove(card2)
        card2 = card2 % len(playable)
        card3 = card2
        if len(hand2) > 0:
            card3 = hand2[random.randint(0, len(hand2)-1)]
            
        upgradeChoice = card2 in hand and card2+118 in hand
        upgradeComp = upgradeChoice and card1 == card2 and upgraded1 == upgraded2
        simpleComp = not upgradeChoice and card1 == card2
        if simpleComp or upgradeComp:
            correct += 1
        approxComp = card1 == card2 or card1 == card3
        t2 += approxComp
        samples += 1
        
    print(correct/samples)
    print(t2/samples)
    print(samples)
    
# 0.3438662967724474
# 0.4849448541849922

def priority_acc():
    random.seed(0)
    ds = SpireData(start=100, end=130)
    
    correct = 0.
    t2 = 0.
    samples = 0
    
    for state, label in ds:
        state = state.numpy()
        hand = []
        for i in range(10):
            c = state[i*137:(i+1)*137]
            c2 = np.argmax(c[:-1])
            #not empty and playable
            if sum(c) > 0 and c2 < len(playable):
                if c[-1] > 0.5:
                    c2 += 118
                hand.append(c2)
        canplay = [c % len(playable) for c in hand]
            
        label = label.numpy()
        card1 = np.argmax(label[:-1])
        upgraded1 = label[-1] > 0.5
        
        card2 = 0
        for card in CARD_PRIORITY_LIST:
            if cardIDs[card] in canplay:
                card2 = cardIDs[card]
                canplay.remove(card2)
                break
        upgraded2 = True
        
        card3 = card2
        if len(canplay) > 0:
            for card in CARD_PRIORITY_LIST:
                if cardIDs[card] in canplay:
                    card3 = cardIDs[card]
                    break
            
        upgradeChoice = card2 in hand and card2+118 in hand
        upgradeComp = upgradeChoice and card1 == card2 and upgraded1 == upgraded2
        simpleComp = not upgradeChoice and card1 == card2
        if simpleComp or upgradeComp:
            correct += 1
        approxComp = card1 == card2 or card1 == card3
        t2 += approxComp
        samples += 1
        
    print(correct/samples)
    print(t2/samples)
    
# 0.41301847215643817
# 0.6545097773868327
    
def eval_stats(net):
    random.seed(0)
    ds = SpireData(start=100, end=130)
    loader = iter(DataLoader(ds, batch_size=100))
    
    correct = 0
    handlen = 0
    samples = 0
    
    rightCards = np.zeros(len(playable))
    wrongCards = np.zeros(len(playable))
    wrongGuesses = np.zeros(len(playable))
    t2right = np.zeros(len(playable))
    cardCount = np.zeros(len(playable))
    handCount = np.zeros(len(cardNames))
    
    playpos = []
    labelpos = []
    for i in range(2,11):
        playpos.append(np.zeros(i))
        labelpos.append(np.zeros(i))
    playpos2 = np.zeros(10)
    labelpos2 = np.zeros(10)
    lens = np.zeros(10)
    
    for inputs,labels in loader:
        if gpu:
            inputs = inputs.to(device)
            labels = labels.to(device)

        preds = net(inputs)
        
        hands = tf.identity(inputs.cpu()).numpy()
        preds = preds.cpu().detach().numpy()
        labels = labels.cpu().detach().numpy()

        for i in range(len(labels)):
            label = labels[i]
            card1 = np.argmax(label[:-1])
            cardCount[card1] += 1

            state = hands[i]
            fullhand = []
            hand = []
            for j in range(10):
                c = state[j*137:(j+1)*137]
                c2 = np.argmax(c[:-1])
                #not empty and playable
                if sum(c) > 0:
                    fullhand.append(c2 % len(playable))
                    handCount[c2] += 1
                    if c2 < len(playable):
                        hand.append(c2)
            handlen += len(np.unique(hand))
            lens[len(np.unique(hand))] += 1
            canplay = [c for c in hand]

            pred = preds[i]
            cards = pred[:-1]
            only_playable = lambda x,i : x if i in canplay else 0
            cards = np.array([only_playable(c,i) for i,c in enumerate(cards)])
            card2 = np.argmax(cards)
        
            cards[card2] = 0
            card3 = np.argmax(cards)
            
            if len(fullhand) > 1:
                labelpos[len(fullhand)-2] += (fullhand == card1)
                playpos[len(fullhand)-2] += (fullhand == card2)
                labelpos2 += np.pad(fullhand == card1, (0, 10-len(fullhand)))
                playpos2 += np.pad(fullhand == card2, (0, 10-len(fullhand)))

            simpleComp = card1 == card2
            if simpleComp:
                correct += 1
                rightCards[card1] += 1
            else:
                wrongCards[card1] += 1
                wrongGuesses[card2] += 1
                if card1 == card3:
                    t2right[card1] += 1
            samples += 1
            
    # print(playpos2)
    # print(labelpos2.astype('int'))
    # for i in range(len(playpos)):
    #     print("%d cards:" % (i+2))
    #     print(playpos[i])
    #     print(labelpos[i])
    # [3920. 3957. 3380. 2691. 1600.  624.  211.   42.    0.    0.]
    # [4003 4008 3404 2514 1512  570  180   45    2    0]
          
    #print(handlen/samples)
    #3.596183774274308 average unique handlen
    #return
    
    #print(lens.astype('int'))
    #[   0 1218 2220 3543 3769 2692 1071  231   33    2]
            
    with np.errstate(divide='ignore'):
        pp = cardCount/handCount[:len(playable)]
        rp = rightCards/cardCount
        wp = wrongCards/cardCount
        wgp = wrongGuesses/handCount[:len(playable)]
        tp = t2right/wrongCards
        # print("%d/%d = %.4f" % (correct, samples, correct/samples))
        # print(handlen/samples)
        # print([cardNames[i] for i in np.flip(np.argsort(handCount))])
        # print("hand count:")
        # print(handCount.astype('int'))
        # print("play count:")
        # print(cardCount.astype('int'))
        # print("play %")
        # print(pp)
        # print("right:")
        # print(rightCards.astype('int'))
        # print("wrong:")
        # print(wrongCards.astype('int'))
        # print("wrong guesses:")
        # print(wrongGuesses.astype('int'))
        # print("t2 right:")
        # print(t2right.astype('int'))
        # print("\nright %")
        # for i in np.flip(np.argsort(rp)):
        #     if cardCount.astype('int')[i] == 0:
        #         continue
        #     print("%.3f: %s (%d)" % (rp[i], cardNames[i], cardCount[i]))
        # print("\nwrong %")
        # for i in np.flip(np.argsort(wp)):
        #     if cardCount.astype('int')[i] == 0:
        #         continue
        #     print("%.3f: %s (%d)" % (wp[i], cardNames[i], cardCount[i]))
        # print(wp)
        # print("wrong guess %")
        # print(wgp)
        # print("t2 %:")
        # print(tp)
        #for i in np.flip(np.argsort(pp)):
        for i in range(len(playable)):
            print(cardNames[i] + ":")
            print("     played %5d / %5d  (%2.1f%%)" % (cardCount[i], handCount[i], pp[i]*100))
            print("    correct %5d / %5d  (%2.1f%%)" % (rightCards[i], cardCount[i], rp[i]*100))
            print("      wrong %5d / %5d  (%2.1f%%)" % (wrongCards[i], cardCount[i], wp[i]*100))
            print("  wr. guess %5d / %5d  (%2.1f%%)" % (wrongGuesses[i], cardCount[i], wgp[i]*100))
            print(" t2 instead %5d / %5d  (%2.1f%%)" % (t2right[i], wrongCards[i], tp[i]*100))
            print()
        
def sample_pred(net, interval):
    random.seed(0)
    ds = SpireData(start=100, end=130)
    loader = iter(DataLoader(ds, batch_size=1))
    
    for (i, (inputs,labels)) in enumerate(loader):
        if i % interval != 0:
            continue
            
        if gpu:
            inputs = inputs.to(device)
            labels = labels.to(device)

        preds = net(inputs)
        
        hands = tf.identity(inputs.cpu()).numpy()
        preds = preds.cpu().detach().numpy()
        labels = labels.cpu().detach().numpy()

        for i in range(len(labels)):
            label = labels[i]
            card1 = np.argmax(label[:-1])
            upgraded1 = label[-1] > 0.5

            state = hands[i]
            hand = []
            for j in range(10):
                c = state[j*137:(j+1)*137]
                c2 = np.argmax(c[:-1])
                #not empty and playable
                if sum(c) > 0 and c2 < len(playable):
                    if c[-1] > 0.5:
                        c2 += 118
                    hand.append(c2)
            canplay = [c % 118 for c in hand]

            pred = preds[i]
            cards = pred[:-1]
            only_playable = lambda x,i : x if i in canplay else 0
            cards = np.array([only_playable(c,i) for i,c in enumerate(cards)])
            card2 = np.argmax(cards)

            upgraded2 = pred[-1] > 0.5
            upgradeChoice = card2 in hand and card2+118 in hand
            upgradeComp = upgradeChoice and card1 == card2 and upgraded1 == upgraded2
            simpleComp = not upgradeChoice and card1 == card2
            print("state:" + str(state[-5:]))
            print("hand:")
            for c in hand:
                str1 = cardNames[c % len(playable)]
                if c >= len(playable):
                    str1 += '+'
                out = cards[c % len(playable)]
                print(" %.3f %s" % (out, str1))
            print("pred:  " + cardNames[card2])
            print("upggr: %.2f" % (pred[-1]))
            print("label: " + cardNames[card1])
            if simpleComp or upgradeComp:
                print("correct")
            else:
                print("wrong")
            print()
                
#sample_pred(net, 1000)
#eval_stats(net)
rand_acc()
#rand_stoch()
#priority_acc()

0.3343652131754624
0.34034448066664974
0.5832226130729942


In [5]:
writer = SummaryWriter('runs/221rl3')

In [10]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear((len(cardNames)+1) * 10 + 6, 2000)
        self.fc2 = nn.Linear(2000, 2000)
        self.fc3 = nn.Linear(2000, 1000)
        #self.fc4 = nn.Linear(1000, 1000)
        #self.fc5 = nn.Linear(100, 400)
        #self.fc6 = nn.Linear(100, 200)
        self.fc4 = nn.Linear(1000, len(playable)+1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        #x = self.relu(self.fc4(x))
        #x = self.relu(self.fc5(x))
        #x = self.relu(self.fc6(x))
        #x = self.relu(self.fc7(x))
        x = self.fc4(x)
        x = self.sigmoid(x)
        return x
    
class Net2(nn.Module):
    def __init__(self):
        super(Net2, self).__init__()
        self.fc1 = nn.Linear((len(cardNames)+1) * 10 + 6, 2000)
        self.fc2 = nn.Linear(2000, 2000)
        self.fc3 = nn.Linear(2000, 1000)
        self.fc4 = nn.Linear(1000, len(playable))
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        canplay = torch.tensor([-1000]).repeat((x.size()[0], len(playable))).float().to(device)
        for c in range(10):
            canplay += x[:,c*(len(cardNames)+1):c*(len(cardNames)+1)+len(playable)]*2000
        canplay = self.sigmoid(canplay)
        
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        x = self.fc4(x)
        x = x * canplay
        x = self.softmax(x)
        x = torch.concat((x, torch.ones((x.size()[0], 1)).to(device)), 1)
        return x

net = Net2()
if(gpu):
    net.to(device)
print(net)

Net2(
  (fc1): Linear(in_features=1396, out_features=2000, bias=True)
  (fc2): Linear(in_features=2000, out_features=2000, bias=True)
  (fc3): Linear(in_features=2000, out_features=1000, bias=True)
  (fc4): Linear(in_features=1000, out_features=120, bias=True)
  (relu): ReLU()
  (sigmoid): Sigmoid()
  (softmax): Softmax(dim=1)
)


In [12]:
net.load_state_dict(torch.load('best2'))

<All keys matched successfully>

In [12]:
torch.save(net.state_dict(), '221aug600')

In [11]:
criterion = nn.BCELoss()
#optimizer = optim.Adam(net.parameters(), lr=0.07, eps=0.1)
optimizer = optim.Adam(net.parameters(), lr=0.2, eps=0.1)

printInterval = 200

epoch0 = 0

best_eval = 0.5789 #0.5824 #0.576 #0.5784
worst_eval = 0.08937 #0.0470

t0 = time.perf_counter()
for epoch in range(100):  # loop over the dataset multiple times
    random.seed(epoch0+epoch)
    
    ds = SpireData(start=130, shuffle=False, aug=False)
    batch_size = 100
    loader = iter(DataLoader(ds, batch_size=batch_size))

    correct = t2 = samples = 0
    epoch_loss = running_loss = 0.0
    for i, data in enumerate(loader, 0):
        # get the inputs
        if (gpu):
            inputs, labels = data[0].to(device), data[1].to(device)
        else:
            inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        if i*batch_size < 15000:
            c,t,s = pred_acc(inputs, outputs, labels)
            correct += c
            t2 += t
            samples += s

        # print statistics
        epoch_loss += loss.item()
        running_loss += loss.item()
        if i % printInterval == printInterval-1:
            print('[%d, %5d] loss: %.4f' %
                  (epoch + 1, i + 1, running_loss / printInterval))
            running_loss = 0.0
    
    print('[%d] loss: %.4f' % (epoch+1, epoch_loss/(i+1.)))
    writer.add_scalar("Loss/train", epoch_loss/(i+1.), epoch+epoch0)
    writer.add_scalar("Acc/train", correct/samples, epoch+epoch0)
    writer.add_scalar("T2/train", t2/samples, epoch+epoch0)
    
    acc, et2, eloss = eval_acc(net, criterion)
    print('[%d] acc: %.4f' % (epoch+1, acc))
    writer.add_scalar("Loss/eval", eloss, epoch+epoch0)
    writer.add_scalar("Acc/eval", acc, epoch+epoch0)
    writer.add_scalar("T2/eval", et2, epoch+epoch0)
    writer.flush()
    
    if acc > best_eval:
        best_eval = acc
        torch.save(net.state_dict(), 'best4')
        print("New record!")
    
    if acc < worst_eval:
        worst_eval = acc
        #torch.save(net.state_dict(), 'worst')
        print("New worst!")
        
    if acc > 0.4:
        torch.save(net.state_dict(), 'last')

print('Finished Training in %ds' % (time.perf_counter()-t0))

[1,   200] loss: 0.5194
[1,   400] loss: 0.5624
[1,   600] loss: 0.5019
[1,   800] loss: 0.4889
[1] loss: 0.5062
[1] acc: 0.4524
[2,   200] loss: 0.4894
[2,   400] loss: 0.5338
[2,   600] loss: 0.4811
[2,   800] loss: 0.4819
[2] loss: 0.4885
[2] acc: 0.4578
[3,   200] loss: 0.4882
[3,   400] loss: 0.5332
[3,   600] loss: 0.4805
[3,   800] loss: 0.4815
[3] loss: 0.4878
[3] acc: 0.4686
[4,   200] loss: 0.4878
[4,   400] loss: 0.5329
[4,   600] loss: 0.4803
[4,   800] loss: 0.4813
[4] loss: 0.4876
[4] acc: 0.4695
[5,   200] loss: 0.4877
[5,   400] loss: 0.5327
[5,   600] loss: 0.4801
[5,   800] loss: 0.4812
[5] loss: 0.4874
[5] acc: 0.4705
[6,   200] loss: 0.4876
[6,   400] loss: 0.5326
[6,   600] loss: 0.4800
[6,   800] loss: 0.4811
[6] loss: 0.4873
[6] acc: 0.4718
[7,   200] loss: 0.4875
[7,   400] loss: 0.5326
[7,   600] loss: 0.4799
[7,   800] loss: 0.4810
[7] loss: 0.4873
[7] acc: 0.4751
[8,   200] loss: 0.4874
[8,   400] loss: 0.5325
[8,   600] loss: 0.4798
[8,   800] loss: 0.4810
[

In [2]:
%tensorboard --logdir runs

In [43]:
from tensorflow.python.summary.summary_iterator import summary_iterator
for summary in tf.compat.v1.train.summary_iterator("runs/first-2421-BCE/events.out.tfevents.1680588554.harp.3125117.0"):
    if summary.step == 0:
        continue
    step = summary.step-50
    tag = summary.summary.value[0].ListFields()[0][1]
    val = summary.summary.value[0].ListFields()[1][1]
    writer.add_scalar(tag, val, step)