In [1]:
import sys
sys.path.append('/content/drive/MyDrive/ColabNotebooks/nyanko_MLP')

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader


#自作モジュール
from features.FeatureFunction2 import (
    AttackFeature_solo,AttackFeature_mass,
    DefenseFeature_solo,DefenseFeature_mass,
    DisturbFeature_solo,DisturbFeature_mass,Wave_DisturbFeature,
    CostFeature,AttributeFeature,
    WPOFeature,RoleFeature,
    RangeFeature,HandleFeature,
    GimmickFeature,AtFreFeature
)
from features.FeatureFunctionSpirit2 import(
    AttackFeatureSpirit,
    DefenseFeatureSpirit,
    DisturbFeatureSpirit,
    WPOFeatureSpirit,
)
from model.Spirit_router import SpiritRouter
from normalizer.base import StandardScaler
from model.mlp import MLP,FullModel
from model.feature_concat import FeatureConcat
from data.Rawdata import Rawdata
from model.normalizedFM import NormalizedFeatureModel
from data.SaveManager import SaveManager
from data.NyankoDataset import NyankoDataset
from data.FeatureEncorder import FeatureEncorder,TensorEncorder

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

# Feature extractors for ordinary rows.
features_normal = [
    AttackFeature_solo(device),
    AttackFeature_mass(device),
    DefenseFeature_solo(device),
    DefenseFeature_mass(device),
    DisturbFeature_solo(device),
    DisturbFeature_mass(device),
    Wave_DisturbFeature(device),
    CostFeature(device),
    AttributeFeature(device),
    WPOFeature(device),
    RoleFeature(device),
    RangeFeature(device),
    HandleFeature(device),
    GimmickFeature(device),
    AtFreFeature(device),
]

# Feature extractors for "spirit" rows. Same slot order as features_normal;
# only the mass-attack, mass-defense, mass-disturb and WPO slots are replaced
# by their *Spirit variants.
# NOTE(review): SpiritRouter presumably selects one of the two lists per row — confirm.
features_spirit = [
    AttackFeature_solo(device),
    AttackFeatureSpirit(device),
    DefenseFeature_solo(device),
    DefenseFeatureSpirit(device),
    DisturbFeature_solo(device),
    DisturbFeatureSpirit(device),
    Wave_DisturbFeature(device),
    CostFeature(device),
    AttributeFeature(device),
    WPOFeatureSpirit(device),
    RoleFeature(device),
    RangeFeature(device),
    HandleFeature(device),
    GimmickFeature(device),
    AtFreFeature(device),
]


rawdata = Rawdata("/content/drive/MyDrive/ColabNotebooks/nyanko_MLP/data/nyanko_DB100_dummy.xlsx")
feature_encoder = FeatureEncorder()
# The tensor encoder is constructed with `device` (CUDA when available), so its
# tensors are presumably created on that device rather than staying on the CPU.
tensor_encoder = TensorEncorder(device=device)

dataset = NyankoDataset(rawdata, feature_encoder, tensor_encoder, target_col="評価値")

router = SpiritRouter(features_normal, features_spirit)

feature_concat = FeatureConcat(router)

# Standardization statistics (mu, sigma) for the concatenated feature vector,
# computed over every row of the dataset.
with torch.no_grad():
    # Copy each row dict so feature extraction cannot mutate the dataset's rows.
    feats = torch.stack(
        [feature_concat(row_dict.copy()) for row_dict, _target in dataset.get_all_rows()]
    )
    mu = feats.mean(dim=0).to(device)
    sigma = feats.std(dim=0)
    # A constant (zero-variance) feature gives sigma == 0. A single-row dataset
    # gives sigma == NaN. Either would make the scaler divide by zero or
    # propagate NaN. Replace such entries with 1 so those features are only
    # centred. (NaN > eps is False, so NaN is replaced too.)
    sigma = torch.where(sigma > 1e-6, sigma, torch.ones_like(sigma)).to(device)


normalizer = StandardScaler(mu, sigma).to(device)

# Model construction: normalized features -> MLP regressor.
norm_feature_model = NormalizedFeatureModel(feature_concat, normalizer)

# Run one sample row through both encoders to discover the input width.
sample_row = rawdata.get_row_dict(0)
feature_encoder_sample = feature_encoder(sample_row)
tensor_encoder_sample = tensor_encoder(feature_encoder_sample)
feature_dim = feature_concat.feature_dim(tensor_encoder_sample)

print("feature_dim:", feature_dim)

mlp = MLP(
    input_dim=feature_dim,
    hidden_dim=36,
    p=0.15,
).to(device)

model = FullModel(norm_feature_model, mlp).to(device)

# Optimizer / loss.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

def collate_fn(batch):
    """Turn a list of (row_dict, target) samples into (rows, scaled_targets).

    The row dicts are passed through unchanged as a tuple; the model consumes
    them one at a time. The targets become a float32 tensor on the module-level
    `device`. They are z-scored with the module-level `dataset`'s target
    statistics: (y - dataset.y_mu) / dataset.y_sigma.
    """
    rows, raw_targets = zip(*batch)
    y = torch.tensor(raw_targets, dtype=torch.float32, device=device)

    # Standardize using the dataset-wide target mean / std.
    y_mean, y_std = dataset.y_mu.to(device), dataset.y_sigma.to(device)
    return rows, (y - y_mean) / y_std

# One sample per batch; collate_fn keeps rows as dicts and z-scores the targets.
loader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    collate_fn=collate_fn,
)

best_loss = float("inf")
saver = SaveManager("/content/drive/MyDrive/ColabNotebooks/nyanko_MLP/weight_data")

# Best-checkpoint policy:
#   - "best" checkpoints are only considered from epoch Best_Start_Epoch on,
#   - a new best must beat the previous best smoothed loss by Min_Improvement,
#   - at most Max_Best_Saves best checkpoints are flagged.
Best_Start_Epoch = 400
Min_Improvement = 0.3
Max_Best_Saves = 5

best_saved = 0

# Exponential moving average of the per-epoch loss. Best-model selection uses
# it so that a single noisy epoch does not trigger a save. It must persist
# across epochs, so it is initialised here, outside the epoch loop.
ema_loss = None
alpha = 0.1

epochs = 701
for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    is_best = False

    for rows, targets in loader:
        optimizer.zero_grad()

        # rows: tuple of row dicts (len == batch_size); the model takes one row at a time.
        preds = torch.stack([model(row) for row in rows])
        # Match the target shape exactly so MSELoss never silently broadcasts
        # (e.g. when the model returns a scalar per row).
        preds = preds.reshape(targets.shape)

        loss = criterion(preds, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Update the smoothed loss once per epoch with this epoch's summed loss.
    if ema_loss is None:
        ema_loss = total_loss
    else:
        ema_loss = alpha * total_loss + (1 - alpha) * ema_loss

    if epoch % 10 == 0:
        print(f"epoch {epoch:4d}  | loss {total_loss:.4f}")

    if (
        epoch >= Best_Start_Epoch
        and best_saved < Max_Best_Saves
        and ema_loss < best_loss - Min_Improvement
    ):
        best_loss = ema_loss
        is_best = True
        best_saved += 1

    # SaveManager decides from (epoch, is_best) whether to write a checkpoint.
    if saver.should_save(epoch, is_best):
        saver.save_model(
            model,
            epoch,
            total_loss,
            is_best=is_best,
            # Normalization statistics needed to reproduce inputs/outputs at inference.
            extra_state={
                "y_mu": dataset.y_mu.clone(),
                "y_sigma": dataset.y_sigma.clone(),
                "feature_mu": mu.detach().clone(),
                "feature_sigma": sigma.detach().clone(),
            },
        )

print("final_loss:", total_loss)






device: cpu
feature_dim: 43
epoch    0  | loss 92.6630
epoch   10  | loss 27.9176
epoch   20  | loss 18.2565
epoch   30  | loss 15.0246
epoch   40  | loss 12.9937
epoch   50  | loss 9.6637
epoch   60  | loss 9.6539
epoch   70  | loss 7.9221
epoch   80  | loss 9.8619
epoch   90  | loss 5.1924
epoch  100  | loss 7.8838
epoch  110  | loss 7.6580
epoch  120  | loss 6.5332
epoch  130  | loss 5.6943
epoch  140  | loss 7.9051
epoch  150  | loss 5.8944
epoch  160  | loss 5.4550
epoch  170  | loss 3.8458
epoch  180  | loss 4.5040
epoch  190  | loss 4.7732
epoch  200  | loss 4.4645
epoch  210  | loss 3.5383
epoch  220  | loss 3.6925
epoch  230  | loss 3.5206
epoch  240  | loss 5.4194
epoch  250  | loss 5.6092
epoch  260  | loss 5.5987
epoch  270  | loss 4.0138
epoch  280  | loss 4.3467
epoch  290  | loss 3.5849
epoch  300  | loss 3.7619
epoch  310  | loss 4.6209
epoch  320  | loss 4.6544
epoch  330  | loss 3.4435
epoch  340  | loss 3.6717
epoch  350  | loss 2.9177
epoch  360  | loss 3.2811
epoch