In [1]:
import sys

# Directory added to the import search path; presumably the root that holds the
# project packages imported below (features, model, data, normalizer) — verify.
PROJECT_ROOT = '/content/drive/MyDrive/ColabNotebooks/nyanko_MLP'

# Notebook cells get re-executed; only append when the entry is missing so that
# sys.path does not accumulate duplicates of the same directory.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts


#自作モジュール
from features.FeatureFunction2 import (
    AttackFeature_solo,AttackFeature_mass,
    DefenseFeature_solo,DefenseFeature_mass,
    DisturbFeature_solo,DisturbFeature_mass,Wave_DisturbFeature,
    CostFeature,AttributeFeature,
    WPOFeature,RoleFeature2,
    RangeFeature,HandleFeature,
    GimmickFeature,AtFreFeature
)
from features.FeatureFunctionSpirit2 import(
    AttackFeatureSpirit,
    DefenseFeatureSpirit,
    DisturbFeatureSpirit,
    WPOFeatureSpirit,
)
from model.Spirit_router import SpiritRouter
from normalizer.base import StandardScaler
from model.mlp import MLP,FullModel
from model.feature_concat import FeatureConcat
from data.Rawdata import Rawdata
from model.normalizedFM import NormalizedFeatureModel
from data.SaveManagerCosine import SaveManager
from data.NyankoDataset import NyankoDataset
from data.FeatureEncorder import FeatureEncorder,TensorEncorder


# Use CUDA when torch reports it as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("device:",device)

# Feature extractor classes for the "normal" route, in concatenation order.
_NORMAL_FEATURE_CLASSES = (
    AttackFeature_solo,
    AttackFeature_mass,
    DefenseFeature_solo,
    DefenseFeature_mass,
    DisturbFeature_solo,
    DisturbFeature_mass,
    Wave_DisturbFeature,
    CostFeature,
    AttributeFeature,
    WPOFeature,
    RoleFeature2,
    RangeFeature,
    HandleFeature,
    GimmickFeature,
    AtFreFeature,
)

# Same layout for the "spirit" route. Four slots differ from the normal list:
# the *_mass attack/defense/disturb extractors and WPOFeature are replaced by
# their *Spirit counterparts; every other slot uses the same class.
_SPIRIT_FEATURE_CLASSES = (
    AttackFeature_solo,
    AttackFeatureSpirit,
    DefenseFeature_solo,
    DefenseFeatureSpirit,
    DisturbFeature_solo,
    DisturbFeatureSpirit,
    Wave_DisturbFeature,
    CostFeature,
    AttributeFeature,
    WPOFeatureSpirit,
    RoleFeature2,
    RangeFeature,
    HandleFeature,
    GimmickFeature,
    AtFreFeature,
)

# Each list gets its own freshly constructed instances (nothing is shared
# between the two routes), every one built with the selected device.
features_normal = [feature_cls(device) for feature_cls in _NORMAL_FEATURE_CLASSES]
features_spirit = [feature_cls(device) for feature_cls in _SPIRIT_FEATURE_CLASSES]


# Raw table loaded from the .xlsx workbook at this path.
rawdata = Rawdata(
    "/content/drive/MyDrive/ColabNotebooks/nyanko_MLP/data/nyanko_DB100_dummy.xlsx"
)

# Two-stage row encoding: FeatureEncorder first, TensorEncorder second
# (that is the order they are applied to a sample row further down).
feature_encoder = FeatureEncorder()
# NOTE(review): an earlier comment here said the tensors stay on the CPU, but
# the encoder receives the selected `device` — confirm which is intended.
tensor_encoder = TensorEncorder(device=device)

# Dataset over the raw rows; the regression target is read from column "評価値".
dataset = NyankoDataset(
    rawdata,
    feature_encoder,
    tensor_encoder,
    target_col="評価値",
)

# The router is given both feature lists; presumably it picks one per row
# (normal vs. spirit) — verify against SpiritRouter.
router = SpiritRouter(features_normal, features_spirit)

# Callable that turns one encoded row into a single concatenated feature vector.
feature_concat = FeatureConcat(router)

# Feature normalisation: mu / sigma are the per-column mean and standard
# deviation of the concatenated feature vector over every row of the dataset.
with torch.no_grad():
    # Each row dict is shallow-copied before the call so feature_concat works
    # on its own dict rather than on the one yielded by the dataset.
    feats = torch.stack(
        [feature_concat(row_dict.copy()) for row_dict, _ in dataset.get_all_rows()]
    )
    mu = feats.mean(dim=0).to(device)
    sigma = feats.std(dim=0).to(device)

    # A column that is constant over the dataset has std == 0, and a one-row
    # dataset yields NaN with torch.std's default (Bessel-corrected) estimator.
    # Either would presumably turn into inf/NaN once the scaler divides by
    # sigma, so degenerate entries fall back to 1 (the column is then only
    # centred). `sigma > 0` is False for NaN as well, so both cases are caught.
    sigma = torch.where(sigma > 0, sigma, torch.ones_like(sigma))

normalizer = StandardScaler(mu, sigma).to(device)

# --- Model construction ---
# Feature extraction + normalisation wrapped as a single module.
norm_feature_model = NormalizedFeatureModel(feature_concat, normalizer)

# Push row 0 through both encoders and ask feature_concat for the resulting
# feature width; that width becomes the MLP input size.
sample_row = rawdata.get_row_dict(0)
feature_encoder_sample = feature_encoder(sample_row)
tensor_encoder_sample = tensor_encoder(feature_encoder_sample)
feature_dim = feature_concat.feature_dim(tensor_encoder_sample)

print("feature_dim:",feature_dim)

# NOTE(review): `p` is presumably the dropout probability — confirm in model.mlp.
mlp = MLP(input_dim=feature_dim, hidden_dim=28, p=0.10)
mlp = mlp.to(device)

model = FullModel(norm_feature_model, mlp)
model = model.to(device)

# --- Optimizer / loss ---
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = nn.MSELoss()

# Cosine annealing with warm restarts: the first cycle spans 200 epochs, each
# following cycle is twice as long, and the learning rate bottoms out at 1e-5.
_scheduler_kwargs = {
    "T_0": 200,
    "T_mult": 2,
    "eta_min": 1e-5,
}
scheduler = CosineAnnealingWarmRestarts(optimizer, **_scheduler_kwargs)


def collate_fn(batch):
    """Collate (row, target) samples into a tuple of rows and a target tensor.

    The rows are passed through untouched as a tuple. The targets are packed
    into one float32 tensor on the module-level ``device`` and z-scored with
    the dataset statistics ``dataset.y_mu`` / ``dataset.y_sigma``.

    Args:
        batch: sequence of ``(row, target)`` pairs handed over by the DataLoader.

    Returns:
        ``(rows, targets)`` where ``targets`` holds ``(target - y_mu) / y_sigma``.
    """
    rows, raw_targets = zip(*batch)
    y = torch.tensor(raw_targets, dtype=torch.float32, device=device)

    # z-score normalisation of the regression targets
    centred = y - dataset.y_mu.to(device)
    standardized = centred / dataset.y_sigma.to(device)
    return rows, standardized

# One sample per step, reshuffled every epoch; collate_fn z-scores the targets.
loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)

# Checkpoint writer. NOTE(review): mode="min" presumably means a lower metric is
# better and recent_k=5 the number of recent-best checkpoints kept — confirm in
# data.SaveManagerCosine.
save_manager = SaveManager(
    save_dir="/content/drive/MyDrive/ColabNotebooks/nyanko_MLP/weight_data",
    recent_k=5,
    mode="min",
)


# The scheduler is created with T_0=200 and T_mult=2, so its cycles last
# 200, 400, 800, 1600 and 3200 epochs; 6200 is exactly the end of the fifth.
epochs = 6200

# Normalisation statistics stored with every checkpoint: target mean/std from
# the dataset and the feature mean/std computed for the scaler. Nothing in the
# loop modifies them, so the snapshot is taken once instead of every epoch.
extra_state = {
    "y_mu": dataset.y_mu.clone(),
    "y_sigma": dataset.y_sigma.clone(),
    "feature_mu": mu.detach().clone(),
    "feature_sigma": sigma.detach().clone(),
}

for epoch in range(epochs):
    model.train()
    total_loss = 0.0

    for i, (rows, targets) in enumerate(loader):
        optimizer.zero_grad()

        # The model is called on one row dict at a time, so run every row of
        # the batch separately and average the per-sample losses. With
        # batch_size=1 this is the same single forward pass and loss value as
        # before; with larger batches no sample is silently dropped.
        sample_losses = [
            criterion(model(row), targets[j:j + 1])
            for j, row in enumerate(rows)
        ]
        loss = torch.stack(sample_losses).mean()

        loss.backward()
        optimizer.step()
        # Fractional epoch: the cosine schedule advances on every batch.
        scheduler.step(epoch + i / len(loader))

        total_loss += loss.item()

    # Read after the last scheduler.step() of the epoch.
    lr = optimizer.param_groups[0]["lr"]
    ave_loss = total_loss / len(loader)

    # -------------------------
    # Checkpointing
    # -------------------------
    # All four save calls take the same keyword arguments.
    checkpoint = dict(
        epoch=epoch,
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        metric=ave_loss,
        lr=lr,
        extra_state=extra_state,
    )

    # 1. "last" checkpoint, every epoch
    save_manager.save_last(**checkpoint)

    # 2. best check (metric is the mean training loss of this epoch)
    is_best = save_manager.save_if_best(**checkpoint)

    # 3. push onto the recent-best set only when the best was updated
    if is_best:
        save_manager.save_recent_best(**checkpoint)

    # 4. optional save when the learning rate is near its minimum
    save_manager.save_if_lr_min(**checkpoint)

    # -------------------------
    # log
    # -------------------------
    if is_best:
        print(f"BEST loss | {ave_loss:.4f}")
    if epoch % 10 == 0:
        print( f"epoch {epoch:4d}  | total_loss {total_loss:.4f} | ave_loss {ave_loss:.4f} | lr {lr:.2e}")





device: cpu
feature_dim: 38
BEST loss | 1.0314
epoch    0  | total_loss 121.7093 | ave_loss 1.0314 | lr 3.00e-04
BEST loss | 0.9651
BEST loss | 0.8858
BEST loss | 0.8656
BEST loss | 0.8094
BEST loss | 0.7766
BEST loss | 0.7592
BEST loss | 0.7297
BEST loss | 0.7285
BEST loss | 0.6826
BEST loss | 0.6707
epoch   10  | total_loss 79.1473 | ave_loss 0.6707 | lr 2.98e-04
BEST loss | 0.6533
BEST loss | 0.6255
BEST loss | 0.5952
BEST loss | 0.5716
BEST loss | 0.5676
BEST loss | 0.5591
BEST loss | 0.5556
epoch   20  | total_loss 67.1665 | ave_loss 0.5692 | lr 2.92e-04
BEST loss | 0.5229
BEST loss | 0.5164
BEST loss | 0.5159
BEST loss | 0.5004
BEST loss | 0.4644
BEST loss | 0.4553
epoch   30  | total_loss 55.1188 | ave_loss 0.4671 | lr 2.83e-04
BEST loss | 0.4448
BEST loss | 0.4337
BEST loss | 0.4211
BEST loss | 0.4173
epoch   40  | total_loss 50.2985 | ave_loss 0.4263 | lr 2.71e-04
BEST loss | 0.4094
BEST loss | 0.3805
epoch   50  | total_loss 50.0042 | ave_loss 0.4238 | lr 2.56e-04
BEST loss |