In [None]:
import torch
import torch.nn as nn
from agent import *
from env import *

# 1. Train Swing Up

In [None]:
# Swing-up policy network: 6 observation inputs, 1 action output,
# constructed with max_action=1.75.
actor_su = ActorSU(obs_dim=6, act_dim=1, max_action=1.75)

# Start from pretrained weights rather than a fresh initialisation.
# NOTE(review): the weight source is decided inside ActorSU — confirm the path there.
actor_su.load_pretrained_weights()

In [None]:
# Four separately constructed CriticSU instances (distinct objects, built in
# this order): two critics followed by their two target counterparts.
q1_su, q2_su, q1_target_su, q2_target_su = (
    CriticSU(obs_dim=6, act_dim=1) for _ in range(4)
)

In [None]:
# Serial-port environment for the pendulum, opened in 'swing_up' mode.
# The port name is machine-specific; adjust it for the local setup.
env_su = InvertedPendulumSerialEnv(
    port='COM12',
    baudrate=921600,
    mode='swing_up',
)

In [None]:
# Bundle the environment, the swing-up actor and the four critics into the
# SAC trainer, together with its hyperparameters and per-component learning rates.
trainer_su = SACTrainer(
    env=env_su,
    gamma=0.995,
    tau=0.005,
    initial_alpha=1.0,
    actor=actor_su,
    q1=q1_su,
    q2=q2_su,
    q1_target=q1_target_su,
    q2_target=q2_target_su,
    lr_actor=2e-4,
    lr_critic=1e-3,
    lr_alpha=1e-3,
)

In [None]:
# Run swing-up training for up to 100 episodes of at most 1000 steps each.
# NOTE(review): window_length / stop_avg_value look like an early-stop rule on
# a moving average — confirm in SACTrainer.train.
#
# The RESET command and the close() call sit in `finally` clauses so that the
# hardware is reset and the environment is released even when training raises
# or the kernel is interrupted. Without this, a failed run skips both calls.
try:
    trainer_su.train(episodes=100, max_steps=1000, window_length=1, stop_avg_value=200)
finally:
    try:
        trainer_su.env.send_action('RESET')
    finally:
        # Still close the environment if sending RESET itself fails.
        trainer_su.env.close()

In [None]:
# Show the results of the swing-up run.
# NOTE(review): presumably plots metrics recorded during train() — confirm in SACTrainer.
trainer_su.visualize_training()

# 2. Train Balance

In [None]:
# Rebind `actor_su` to a newly constructed swing-up actor and call load_model()
# on it (this replaces any object previously held under that name).
# NOTE(review): presumably this restores the weights saved by swing-up
# training — confirm what ActorSU.load_model() reads.
actor_su = ActorSU(obs_dim=6, act_dim=1, max_action=1.75)
actor_su.load_model()

# Balance policy network: same observation/action sizes, max_action=5.0.
actor_b = ActorB(obs_dim=6, act_dim=1, max_action=5.0)

In [None]:
# Four separately constructed CriticB instances (distinct objects, built in
# this order): two critics followed by their two target counterparts.
q1_b, q2_b, q1_target_b, q2_target_b = (
    CriticB(obs_dim=6, act_dim=1) for _ in range(4)
)

In [None]:
# Serial-port environment for the pendulum, opened in 'balance' mode.
# The port name is machine-specific; adjust it for the local setup.
env_b = InvertedPendulumSerialEnv(
    port='COM12',
    baudrate=921600,
    mode='balance',
)

In [None]:
# Balance-phase trainer: receives the balance actor and critics, plus the
# swing-up actor via `actor_su`.
# NOTE(review): presumably actor_su drives the swing-up phase before the
# balance policy takes over, and load_model names a checkpoint to resume
# from — confirm both in SACBalanceTrainer.
trainer_b = SACBalanceTrainer(
    env=env_b,
    gamma=0.995,
    tau=0.005,
    initial_alpha=0.1,
    actor=actor_b,
    q1=q1_b,
    q2=q2_b,
    q1_target=q1_target_b,
    q2_target=q2_target_b,
    lr_actor=1e-5,
    lr_critic=1e-4,
    lr_alpha=1e-4,
    actor_su=actor_su,
    load_model='sac_model_b.pth',
    load_pretrain=False,
)

In [None]:
# Run balance training for up to 100 episodes of at most 3000 steps each.
# NOTE(review): window_length / stop_avg_value look like an early-stop rule on
# a moving average — confirm in SACBalanceTrainer.train.
#
# The RESET command and the close() call sit in `finally` clauses so that the
# hardware is reset and the environment is released even when training raises
# or the kernel is interrupted. Without this, a failed run skips both calls.
try:
    trainer_b.train(episodes=100, max_steps=3000, window_length=10, stop_avg_value=3000)
finally:
    try:
        trainer_b.env.send_action('RESET')
    finally:
        # Still close the environment if sending RESET itself fails.
        trainer_b.env.close()

In [None]:
# Show the results of the balance run.
# NOTE(review): presumably plots metrics recorded during train() — confirm in SACBalanceTrainer.
trainer_b.visualize_training()