# Quick Start

## How to build a GPU-accelerated flight dynamics model (FDM)
### 1. Download your data
For example: `wget https://dept.aem.umn.edu/~./faculty/balas/darpa_sec/software/F16Simulation.tar.gz`
Download the aerodynamic data and rename its directory to `data`; the training code reads the tables from `./data/` (e.g. `./data/ALPHA1.dat`).

### 2. Train your MLP model
Example code; for the full script, see `./train_model/train_model.py`.

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score
import csv
from train_model.hifi_F16_AeroData import hifi_F16
device = "cuda:0"


class MLP(nn.Module):
    """Fully connected regressor: ``Linear + ReLU`` per hidden width, then a final ``Linear``.

    Args:
        in_dim: number of input features.
        out_dim: number of output features.
        hidden_list: widths of the hidden layers, in order (may be empty).
    """

    def __init__(self, in_dim, out_dim, hidden_list):
        super().__init__()
        widths = [in_dim, *hidden_list]
        modules = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            modules += [nn.Linear(n_in, n_out), nn.ReLU()]
        # Output head has no activation: this is a plain regression network.
        modules.append(nn.Linear(widths[-1], out_dim))
        self.layers = nn.Sequential(*modules)
        self.out_dim = out_dim

    def forward(self, x):
        """Cast ``x`` to float32 and run it through the layer stack."""
        return self.layers(x.to(torch.float32))


class MyDataset(Dataset):
    """Map-style dataset that pairs ``input[i]`` with ``output[i]``.

    ``transform`` is stored for API compatibility but is not applied by this class.
    """

    def __init__(self, input, output, transform=None):
        super().__init__()
        self.input, self.output = input, output
        self.transform = transform

    def __getitem__(self, index):
        """Return the ``(input, output)`` pair at ``index``."""
        return self.input[index], self.output[index]

    def __len__(self):
        """Number of samples, taken from the length of ``input``."""
        return len(self.input)


def safe_read_dat(dat_name, data_dir='./data/'):
    """Read a file of whitespace-separated floats from ``data_dir``.

    Args:
        dat_name: file name inside ``data_dir`` (e.g. ``'ALPHA1.dat'``).
        data_dir: directory holding the aero data files; defaults to ``./data/``.

    Returns:
        A 1-D ``np.ndarray`` of the parsed values, or ``[]`` if the file cannot
        be opened (an error message is printed in that case).

    Raises:
        ValueError: if the file contains a token that is not a float.
    """
    import os

    path = os.path.join(data_dir, dat_name)
    # Only the I/O can legitimately fail with OSError; keep the try that narrow.
    try:
        with open(path, 'r', encoding='utf-8') as file:
            content = file.read()
    except OSError:
        print("Cannot find file {} in current directory".format(path))
        return []
    # str.split() with no argument splits on any whitespace run (spaces, tabs,
    # newlines). split(' ') left newlines inside tokens, so multi-line files
    # such as "1.0\n2.0" failed to parse.
    return np.array([float(token) for token in content.split()])


def normalize(X):
    """Standardize ``X`` to zero mean and unit (unbiased) standard deviation over all elements."""
    center = torch.mean(X)
    scale = torch.std(X)
    return (X - center) / scale

def _t2n(x):
    return x.detach().cpu().numpy()


def adjust_opt(optimizer, epoch):
    """Step-decay learning-rate schedule applied in place to every param group.

    The lr is set to 5e-3 at epoch 500, 1e-3 at epoch 750 and 5e-4 at epoch 900;
    on every other epoch the optimizer is left untouched.
    """
    for milestone, new_lr in ((500, 5e-3), (750, 1e-3), (900, 5e-4)):
        if epoch == milestone:
            break
    else:
        return
    for group in optimizer.param_groups:
        group['lr'] = new_lr


def _run_epoch(model, loader, loss_fn, optimizer=None):
    """Run one pass over ``loader``; update weights only when ``optimizer`` is given.

    Returns:
        ``(mean per-batch loss, R^2 score over every sample seen in the pass)``.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    targets, preds = [], []
    # Evaluation passes need no autograd graph.
    with torch.set_grad_enabled(training):
        for X, y in loader:
            X = X.to(device=device, dtype=torch.float32)
            y = y.to(device=device, dtype=torch.float32)
            out = model(X)
            loss_value = loss_fn(out, y)
            if training:
                optimizer.zero_grad()
                loss_value.backward()
                optimizer.step()
            total_loss += float(loss_value)
            # Detach so the per-epoch buffers do not keep every batch's graph alive.
            targets.append(y.detach())
            preds.append(out.detach())
    y_all = torch.cat(targets).cpu().numpy()
    pred_all = torch.cat(preds).cpu().numpy()
    return total_loss / len(loader), r2_score(y_all, pred_all)


def train(train_X, train_Y, file_name):
    """Fit an ``MLP(3 -> [20, 10] -> 1)`` regressor and save the best checkpoint and loss curves.

    The data is split 80/20 into train/test. The model is trained for 1000 epochs
    of SGD with L1 loss, using the step schedule in ``adjust_opt``.

    Outputs:
        ``./model/<file_name>-<best test R^2>-<its test loss>.pth``: the weights from
        the epoch with the highest test R^2 above 0.97. This is an empty dict if no
        epoch beat that threshold.
        ``./train_result/<file_name>_result_loss.csv``: per-epoch train loss,
        test loss and test R^2.

    Args:
        train_X: input tensor of shape (N, 3).
        train_Y: target tensor of shape (N, 1).
        file_name: prefix for the checkpoint and result files.
    """
    import copy

    # The inputs may still be attached to the autograd graph that generated them
    # (built from tensors created with requires_grad=True). Detach so every
    # optimizer step back-propagates through the MLP only; previously each step
    # also walked the data-generation graph, which forced retain_graph=True.
    train_X = train_X.detach()
    train_Y = train_Y.detach()
    X_train, X_test, y_train, y_test = train_test_split(train_X, train_Y, test_size=0.2, shuffle=True)
    BATCH_SIZE = 32
    train_loader = DataLoader(MyDataset(X_train, y_train), batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(MyDataset(X_test, y_test), batch_size=BATCH_SIZE, shuffle=True)
    model = MLP(3, 1, [20, 10]).to(device)
    loss = nn.L1Loss()
    optimizer = optim.SGD(model.parameters(), lr=0.006, momentum=0.9, weight_decay=5e-4)
    num_epochs = 1000
    train_loss_list, train_r2_list = [], []
    test_loss_list, test_r2_list = [], []
    # Only checkpoints whose test R^2 beats this threshold are kept.
    max_test_r2 = 0.97
    best_model = {}
    min_test_loss = 0
    for epoch in range(num_epochs):
        adjust_opt(optimizer, epoch)
        train_loss, train_r2 = _run_epoch(model, train_loader, loss, optimizer)
        train_loss_list.append(train_loss)
        train_r2_list.append(train_r2)
        print('epoch', epoch, ':')
        print('train_loss:', train_loss_list[-1])
        print('train_r2:', train_r2_list[-1])
        test_loss, test_r2 = _run_epoch(model, test_loader, loss)
        test_loss_list.append(test_loss)
        test_r2_list.append(test_r2)
        print('test_loss:', test_loss_list[-1])
        print('test_r2:', test_r2_list[-1])
        print('max_test_r2:', max_test_r2)
        if test_r2 > max_test_r2:
            # state_dict() returns references to the live parameters; copy them so
            # later epochs cannot overwrite the snapshot of the best model.
            best_model = copy.deepcopy(model.state_dict())
            min_test_loss = test_loss_list[-1]
            max_test_r2 = test_r2
    if not best_model:
        print('warning: test R^2 never exceeded the threshold; saving an empty checkpoint for', file_name)
    torch.save(best_model, "./model/" + file_name + "-" + str(max_test_r2) + "-" + str(min_test_loss) + ".pth")
    with open("./train_result/" + file_name + "_result_loss" + ".csv", 'w', newline='') as tmp:
        csv_write = csv.writer(tmp)
        csv_write.writerow(["train_loss", "test_loss", "test_r2"])
        csv_write.writerows(zip(train_loss_list, test_loss_list, test_r2_list))


# generate train data
# Sample the high-fidelity aero model on a GRID_POINTS^3 (alpha, beta, el) grid
# spanning the first/last breakpoints of ALPHA1.dat, BETA1.dat and DH1.dat.
hifi = hifi_F16()
ALPHA1 = safe_read_dat(r'ALPHA1.dat')
ALPHA2 = safe_read_dat(r'ALPHA2.dat')
BETA1 = safe_read_dat(r'BETA1.dat')
DH1 = safe_read_dat(r'DH1.dat')
DH2 = safe_read_dat(r'DH2.dat')
GRID_POINTS = 30
raw_alpha1 = np.linspace(ALPHA1[0], ALPHA1[-1], GRID_POINTS)
raw_beta = np.linspace(BETA1[0], BETA1[-1], GRID_POINTS)
raw_el = np.linspace(DH1[0], DH1[-1], GRID_POINTS)
# Flatten the full tensor-product grid with alpha varying slowest and el fastest.
alpha, beta, el = (axis.reshape(-1) for axis in np.meshgrid(raw_alpha1, raw_beta, raw_el, indexing='ij'))
alpha = torch.tensor(alpha, device=torch.device(device), requires_grad=True)
beta = torch.tensor(beta, device=torch.device(device), requires_grad=True)
el = torch.tensor(el, device=torch.device(device), requires_grad=True)
Cx = hifi._Cx(alpha, beta, el)
Cz = hifi._Cz(alpha, beta, el)
Cm = hifi._Cm(alpha, beta, el)
Cy = hifi._Cy(alpha, beta)  # called with (alpha, beta) only; not used by the 3-input models below
Cn = hifi._Cn(alpha, beta, el)
Cl = hifi._Cl(alpha, beta, el)
COEFF_NAMES = ("Cx", "Cz", "Cm", "Cn", "Cl")
coeffs = dict(zip(COEFF_NAMES, (Cx, Cz, Cm, Cn, Cl)))

# Record the raw mean/std of inputs and targets so the normalization below can be undone later.
input_stats = [_t2n(stat(v)) for v in (alpha, beta, el) for stat in (torch.mean, torch.std)]
with open("./mean_std.csv", 'w', newline='') as tmp:
    csv_write = csv.writer(tmp)
    csv_write.writerow(["name", "alpha_mean", "alpha_std", "beta_mean", "beta_std", "el_mean", "el_std", "mean", "std"])
    for name, coeff in coeffs.items():
        csv_write.writerow([name, *input_stats, _t2n(torch.mean(coeff)), _t2n(torch.std(coeff))])

# normalize data
alpha = normalize(alpha)
beta = normalize(beta)
el = normalize(el)
Cx, Cz, Cm, Cn, Cl = (normalize(coeffs[name]) for name in COEFF_NAMES)

# train model: one MLP per coefficient on the shared (alpha, beta, el) inputs
train_X = torch.hstack((alpha.reshape(-1, 1), beta.reshape(-1, 1), el.reshape(-1, 1)))
for name, coeff in zip(COEFF_NAMES, (Cx, Cz, Cm, Cn, Cl)):
    train(train_X=train_X, train_Y=coeff.reshape(-1, 1), file_name=name)

### 3. Build your FDM
Build the GPU-accelerated FDM as a batched PyTorch module.

In [1]:
import os
import sys
import torch
import torch.nn as nn
sys.path.append(os.path.dirname(os.path.abspath('.')))
from envs.models.F16.hifi_F16_AeroData import hifi_F16


class F16Dynamics(nn.Module):
    """Batched nonlinear F-16 equations of motion, usable as an ODE right-hand side.

    ``forward(t, x)`` has the ``func(t, y)`` signature used by torchdiffeq's
    ``odeint``. Aerodynamic coefficients are looked up through ``hifi_F16``.
    """
    def __init__(self, device):
        super().__init__()
        self.hifi_F16 = hifi_F16(device=device)

    def compute_extended_state(self, x):
        """Return the time derivative of ``x`` (see ``nlplant``)."""
        return self.nlplant(x)

    def forward(self, t, x):
        """ODE right-hand side; ``t`` is unused because the dynamics do not depend on time."""
        es = self.compute_extended_state(x)
        return es
    
    def atmos(self, alt, vt):
        """Return ``(mach, qbar, ps)`` for altitude ``alt`` and airspeed ``vt``.

        Density uses the fit ``rho = rho0 * tfac**4.14`` with
        ``tfac = 1 - 0.703e-5 * alt``. The temperature is held at 390 at or above
        35000 ft (``rho`` itself is not clamped). A zero static pressure is
        replaced by 1715.
        """
        # Compute dynamic pressure, Mach number, and static pressure from altitude and airspeed
        rho0 = 2.377e-3
        tfac = 1 - .703e-5 * (alt)
        temp = 519.0 * tfac
        # Branch-free select so this works element-wise on batched tensors.
        temp = (alt >= 35000.0) * 390 + (alt < 35000.0) * temp
        rho = rho0 * pow(tfac, 4.14)
        mach = (vt) / torch.sqrt(1.4 * 1716.3 * temp)
        qbar = .5 * rho * pow(vt, 2)
        ps = 1715.0 * rho * temp

        ps = (ps == 0) * 1715 + (ps != 0) * ps

        return (mach, qbar, ps)
    
    def nlplant(self, x):
        """
        Compute the state derivative for a batch of aircraft.

        ``x`` is indexed as ``x[:, k]`` for k = 0..16: columns 0-11 are the model
        state and columns 12-16 are the model control, laid out as below.
        Returns ``xdot`` with the same shape as ``x``. Only columns 0-11 are
        filled; control columns 12-16 stay zero, so an ODE solver integrating
        the hstacked (state, control) vector holds the controls constant.

        model state(dim 12):
            0. ego_north_position      (unit: feet)
            1. ego_east_position       (unit: feet)
            2. ego_altitude            (unit: feet)
            3. ego_roll                (unit: rad)
            4. ego_pitch               (unit: rad)
            5. ego_yaw                 (unit: rad)
            6. ego_vt                  (unit: feet/s)
            7. ego_alpha               (unit: rad)
            8. ego_beta                (unit: rad)
            9. ego_P                   (unit: rad/s)
            10. ego_Q                  (unit: rad/s)
            11. ego_R                  (unit: rad/s)

        model control(dim 5)
            0. ego_T                  (unit: lbf)
            1. ego_el                 (unit: deg)
            2. ego_ail                (unit: deg)
            3. ego_rud                (unit: deg)
            4. ego_lef                (unit: deg)
        """
        xdot = torch.zeros_like(x)
        g = 32.17
        m = 636.94
        B = 30.0
        S = 300.0
        cbar = 11.32
        xcgr = 0.35
        xcg = 0.30
        Heng = 0.0
        pi = torch.pi

        # Moments and product of inertia.
        Jy = 55814.0
        Jxz = 982.0
        Jz = 63100.0
        Jx = 9496.0

        r2d = 180.0 / pi

        # States
        alt = x[:, 2]
        phi = x[:, 3]
        theta = x[:, 4]
        psi = x[:, 5]

        vt = x[:, 6]
        # alpha/beta in degrees are passed to the coefficient lookups below;
        # the trig terms use the radian values straight from x.
        alpha = x[:, 7] * r2d
        beta = x[:, 8] * r2d
        P = x[:, 9]
        Q = x[:, 10]
        R = x[:, 11]

        sa = torch.sin(x[:, 7])
        ca = torch.cos(x[:, 7])
        sb = torch.sin(x[:, 8])
        cb = torch.cos(x[:, 8])

        st = torch.sin(theta)
        ct = torch.cos(theta)
        tt = torch.tan(theta)
        sphi = torch.sin(phi)
        cphi = torch.cos(phi)
        spsi = torch.sin(psi)
        cpsi = torch.cos(psi)

        # Clamp airspeed to at least 0.01 (branch-free) since it appears in denominators below.
        vt = (vt <= 0.01) * 0.01 + (vt > 0.01) * vt

        # Control inputs

        T = x[:, 12]
        el = x[:, 13]
        ail = x[:, 14]
        rud = x[:, 15]
        lef = x[:, 16]

        # Aileron/rudder deflections scaled by 21.5 and 30; dlef = 1 - lef/25.
        dail = ail / 21.5
        drud = rud / 30.0
        dlef = (1 - lef / 25.0)

        # Atmospheric effects
        # sets dynamic pressure and mach number

        temp = self.atmos(alt, vt)
        mach = temp[0]
        qbar = temp[1] # dynamic pressure
        ps = temp[2]

        # Dynamics
        # Navigation Equations

        # Body-axis velocity components from (vt, alpha, beta).
        U = vt * ca * cb
        V = vt * sb
        W = vt * sa * cb

        xdot[:, 0] = U * (ct * cpsi) + V * (sphi * cpsi * st - cphi * spsi) + W * (cphi * st * cpsi + sphi * spsi)
        xdot[:, 1] = U * (ct * spsi) + V * (sphi * spsi * st + cphi * cpsi) + W * (cphi * st * spsi - sphi * cpsi)
        xdot[:, 2] = U * st - V * (sphi * ct) - W * (cphi * ct)
        # Euler-angle kinematics (singular where cos(theta) = 0).
        xdot[:, 3] = P + tt * (Q * sphi + R * cphi)
        xdot[:, 4] = Q * cphi - R * sphi
        xdot[:, 5] = (Q * sphi + R * cphi) / ct

        temp = self.hifi_F16.hifi_C(alpha, beta, el)
        Cx = temp[0]
        Cz = temp[1]
        Cm = temp[2]
        Cy = temp[3]
        Cn = temp[4]
        Cl = temp[5]

        temp = self.hifi_F16.hifi_damping(alpha)
        Cxq = temp[0]
        Cyr = temp[1]
        Cyp = temp[2]
        Czq = temp[3]
        Clr = temp[4]
        Clp = temp[5]
        Cmq = temp[6]
        Cnr = temp[7]
        Cnp = temp[8]

        temp = self.hifi_F16.hifi_C_lef(alpha, beta)
        delta_Cx_lef = temp[0]
        delta_Cz_lef = temp[1]
        delta_Cm_lef = temp[2]
        delta_Cy_lef = temp[3]
        delta_Cn_lef = temp[4]
        delta_Cl_lef = temp[5]

        # temp[3] of hifi_damping_lef is not read here (see the NOTE at dZdQ).
        temp = self.hifi_F16.hifi_damping_lef(alpha)
        delta_Cxq_lef = temp[0]
        delta_Cyr_lef = temp[1]
        delta_Cyp_lef = temp[2]
        delta_Clr_lef = temp[4]
        delta_Clp_lef = temp[5]
        delta_Cmq_lef = temp[6]
        delta_Cnr_lef = temp[7]
        delta_Cnp_lef = temp[8]

        temp = self.hifi_F16.hifi_rudder(alpha, beta)
        delta_Cy_r30 = temp[0]
        delta_Cn_r30 = temp[1]
        delta_Cl_r30 = temp[2]

        temp = self.hifi_F16.hifi_ailerons(alpha, beta)
        delta_Cy_a20 = temp[0]
        delta_Cy_a20_lef = temp[1]
        delta_Cn_a20 = temp[2]
        delta_Cn_a20_lef = temp[3]
        delta_Cl_a20 = temp[4]
        delta_Cl_a20_lef = temp[5]

        temp = self.hifi_F16.hifi_other_coeffs(alpha, el)
        delta_Cnbeta = temp[0]
        delta_Clbeta = temp[1]
        delta_Cm = temp[2]
        eta_el = temp[3]
        delta_Cm_ds = temp[4]
        
        # Total force/moment coefficients: base table + LEF increments + rate damping
        # (damping terms scaled by cbar/(2 vt) or B/(2 vt)).
        dXdQ = (cbar / (2 * vt)) * (Cxq + delta_Cxq_lef * dlef)
        Cx_tot = Cx + delta_Cx_lef * dlef + dXdQ * Q
        # NOTE(review): dZdQ uses delta_Cz_lef (the Cz LEF increment), not a Czq LEF
        # damping increment; temp[3] of hifi_damping_lef is never read. This may
        # mirror the reference implementation this was ported from. Confirm before changing.
        dZdQ = (cbar / (2 * vt)) * (Czq + delta_Cz_lef * dlef)
        Cz_tot = Cz + delta_Cz_lef * dlef + dZdQ * Q
        dMdQ = (cbar / (2 * vt)) * (Cmq + delta_Cmq_lef * dlef)
        Cm_tot = Cm * eta_el + Cz_tot * (xcgr - xcg) + delta_Cm_lef * dlef + dMdQ * Q + delta_Cm + delta_Cm_ds
        dYdail = delta_Cy_a20 + delta_Cy_a20_lef * dlef
        dYdR = (B / (2 * vt)) * (Cyr + delta_Cyr_lef * dlef)
        dYdP = (B / (2 * vt)) * (Cyp + delta_Cyp_lef * dlef)
        Cy_tot = Cy + delta_Cy_lef * dlef + dYdail * dail + delta_Cy_r30 * drud + dYdR * R + dYdP * P
        dNdail = delta_Cn_a20 + delta_Cn_a20_lef * dlef
        dNdR = (B / (2 * vt)) * (Cnr + delta_Cnr_lef * dlef)
        dNdP = (B / (2 * vt)) * (Cnp + delta_Cnp_lef * dlef)
        Cn_tot = Cn + delta_Cn_lef * dlef - Cy_tot * (xcgr - xcg) * (cbar / B) + dNdail * dail + delta_Cn_r30 * drud + dNdR * R + dNdP * P + delta_Cnbeta * beta
        dLdail = delta_Cl_a20 + delta_Cl_a20_lef * dlef
        dLdR = (B / (2 * vt)) * (Clr + delta_Clr_lef * dlef)
        dLdP = (B / (2 * vt)) * (Clp + delta_Clp_lef * dlef)
        Cl_tot = Cl + delta_Cl_lef * dlef + dLdail * dail + delta_Cl_r30 * drud + dLdR * R + dLdP * P + delta_Clbeta * beta
        # Body-axis accelerations: rotational coupling, gravity, aero force, thrust (x only).
        Udot = R * V - Q * W - g * st + qbar * S * Cx_tot / m + T / m
        Vdot = P * W - R * U + g * ct * sphi + qbar * S * Cy_tot / m
        Wdot = Q * U - P * V + g * ct * cphi + qbar * S * Cz_tot / m
        # Convert (Udot, Vdot, Wdot) back to vt, alpha and beta rates.
        xdot[:, 6] = (U * Udot + V * Vdot + W * Wdot) / vt
        xdot[:, 7] = (U * Wdot - W * Udot) / (U * U + W * W)
        xdot[:, 8] = (Vdot * vt - V * xdot[:, 6]) / (vt * vt * cb)
        L_tot = Cl_tot * qbar * S * B
        M_tot = Cm_tot * qbar * S * cbar
        N_tot = Cn_tot * qbar * S * B
        # Roll and yaw equations are coupled through Jxz; denom is the determinant of that 2x2 system.
        denom = Jx * Jz - Jxz * Jxz
        xdot[:, 9] = (Jz * L_tot + Jxz * N_tot - (Jz * (Jz - Jy) + Jxz * Jxz) * Q * R + Jxz * (Jx - Jy + Jz) * P * Q + Jxz * Q * Heng) / denom
        xdot[:, 10] = (M_tot + (Jz - Jx) * P * R - Jxz * (P * P - R * R) - R * Heng) / Jy
        xdot[:, 11] = (Jx * N_tot + Jxz * L_tot + (Jx * (Jx - Jy) + Jxz * Jxz) * P * Q - Jxz * (Jx - Jy + Jz) * Q * R + Jx * Q * Heng) / denom

        return xdot

### 4. Build your model
Build your own model. For details of the model interface, see `../envs/models/model_base.py`.

In [2]:
import os
import sys
import torch
from torchdiffeq import odeint_adjoint as odeint
sys.path.append(os.path.dirname(os.path.abspath('.')))
from envs.models.model_base import BaseModel


class F16Model(BaseModel):
    """Batch of ``n`` F-16 aircraft integrated in parallel with ``F16Dynamics``.

    ``self.s`` has shape (n, num_states) and ``self.u`` has shape (n, num_controls);
    the column layout is documented in ``F16Dynamics.nlplant``. Settings are read
    from ``self.config`` with ``getattr`` defaults.
    """

    def __init__(self, config, n, device, random_seed):
        super().__init__(config, n, device, random_seed)
        self.num_states = getattr(self.config, 'num_states', 12)
        self.num_controls = getattr(self.config, 'num_controls', 5)
        self.dt = getattr(self.config, 'dt', 0.02)
        self.solver = getattr(self.config, 'solver', 'euler')
        self.airspeed = getattr(self.config, 'airspeed', 0)

        self.s = torch.zeros((self.n, self.num_states), device=self.device)  # state
        self.recent_s = torch.zeros((self.n, self.num_states), device=self.device)  # previous state
        self.u = torch.zeros((self.n, self.num_controls), device=self.device)  # control
        self.recent_u = torch.zeros((self.n, self.num_controls), device=self.device)  # previous control

        # Ranges used to sample initial altitude and airspeed on reset.
        self.max_altitude = getattr(self.config, 'max_altitude', 20000)
        self.min_altitude = getattr(self.config, 'min_altitude', 19000)
        self.max_vt = getattr(self.config, 'max_vt', 1200)
        self.min_vt = getattr(self.config, 'min_vt', 1000)

        self.dynamics = F16Dynamics(device)

    def reset(self, env):
        """Re-initialize every aircraft whose env is done, bad-done, or past its time limit.

        State and control are zeroed, then altitude and airspeed are drawn
        uniformly from [min_altitude, max_altitude] and [min_vt, max_vt].
        """
        reset = (env.is_done.bool() | env.bad_done.bool()) | env.exceed_time_limit.bool()
        self.s[reset] = 0.0
        self.u[reset] = 0.0
        self.s[reset, 2] = torch.rand_like(self.s[reset, 2]) * (self.max_altitude - self.min_altitude) + self.min_altitude
        self.s[reset, 6] = torch.rand_like(self.s[reset, 6]) * (self.max_vt - self.min_vt) + self.min_vt
        self.recent_s[reset] = self.s[reset]
        self.recent_u[reset] = self.u[reset]

    def get_extended_state(self):
        """Return the derivative of the hstacked (state, control) vector from the dynamics."""
        x = torch.hstack((self.s, self.u))
        return self.dynamics.nlplant(x)

    def update(self, action):
        """Blend the normalized ``action`` into the controls and integrate one ``dt`` step.

        ``action`` is clipped to [-1, 1]. Columns 0-3 are scaled to thrust and
        elevator/aileron/rudder deflection (45 deg full scale). Each control moves
        10% of the way toward its commanded value; the leading-edge flap is held at 0.
        """
        action = torch.clamp(action, -1, 1)
        T = 0.9 * self.u[:, 0].reshape(-1, 1) + 0.1 * action[:, 0].reshape(-1, 1) * 0.225 * 76300 / 0.3048
        el = 0.9 * self.u[:, 1].reshape(-1, 1) + 0.1 * action[:, 1].reshape(-1, 1) * 45
        ail = 0.9 * self.u[:, 2].reshape(-1, 1) + 0.1 * action[:, 2].reshape(-1, 1) * 45
        rud = 0.9 * self.u[:, 3].reshape(-1, 1) + 0.1 * action[:, 3].reshape(-1, 1) * 45
        lef = torch.zeros((self.n, 1), device=self.device)
        self.recent_u = self.u
        self.u = torch.hstack((T, el, ail, rud, lef))
        self.recent_s = self.s
        # Integrate (state, control) together; the dynamics leave the control
        # derivative at zero, so only the first num_states columns are kept.
        self.s = odeint(self.dynamics,
                        torch.hstack((self.s, self.u)),
                        torch.tensor([0., self.dt], device=self.device),
                        method=self.solver)[1, :, :self.num_states]

    def get_state(self):
        return self.s

    def get_control(self):
        return self.u

    def get_position(self):
        """Return (north, east, altitude)."""
        return self.s[:, 0], self.s[:, 1], self.s[:, 2]

    def get_ground_speed(self):
        """Return the (north, east) position rates."""
        es = self.get_extended_state()
        return es[:, 0], es[:, 1]

    def get_climb_rate(self):
        """Return the altitude rate."""
        es = self.get_extended_state()
        return es[:, 2]

    def get_posture(self):
        """Return (roll, pitch, yaw)."""
        return self.s[:, 3], self.s[:, 4], self.s[:, 5]

    def get_euler_angular_velocity(self):
        """Return the (roll, pitch, yaw) rates."""
        es = self.get_extended_state()
        return es[:, 3], es[:, 4], es[:, 5]

    def get_vt(self):
        return self.s[:, 6]

    def get_TAS(self):
        """Return vt plus the configured ``airspeed`` offset."""
        return self.s[:, 6] + self.airspeed * torch.ones_like(self.s[:, 6])

    def get_EAS(self):
        """Return equivalent airspeed, TAS / EAS2TAS."""
        TAS = self.get_TAS()
        EAS2TAS = self.get_EAS2TAS()
        EAS = TAS / EAS2TAS
        return EAS

    def get_AOA(self):
        return self.s[:, 7]

    def get_AOS(self):
        return self.s[:, 8]

    def get_angular_velocity(self):
        """Return body rates (P, Q, R)."""
        return self.s[:, 9], self.s[:, 10], self.s[:, 11]

    def get_thrust(self):
        return self.u[:, 0]

    def get_control_surface(self):
        """Return (elevator, aileron, rudder, leading-edge flap)."""
        return self.u[:, 1], self.u[:, 2], self.u[:, 3], self.u[:, 4]

    def get_velocity(self):
        """Return body-axis velocity components (u, v, w) from vt, alpha and beta."""
        vt, alpha, beta = self.s[:, 6], self.s[:, 7], self.s[:, 8]
        cosb = torch.cos(beta)
        vel_u = vt * cosb * torch.cos(alpha)  # x-axis velocity
        vel_v = vt * torch.sin(beta)  # y-axis velocity
        vel_w = vt * cosb * torch.sin(alpha)  # z-axis velocity
        return vel_u, vel_v, vel_w

    def get_acceleration(self):
        """Return body-axis accelerations (ax, ay, az) = d(u, v, w)/dt + omega x (u, v, w).

        d(u, v, w)/dt comes from differentiating the (vt, alpha, beta) -> (u, v, w)
        relation using the vt/alpha/beta rates from ``get_extended_state``.
        """
        xdot = self.get_extended_state()
        vt, alpha, beta = self.s[:, 6], self.s[:, 7], self.s[:, 8]
        P, Q, R = self.s[:, 9], self.s[:, 10], self.s[:, 11]
        sina, cosa = torch.sin(alpha), torch.cos(alpha)
        sinb, cosb = torch.sin(beta), torch.cos(beta)
        vel_u, vel_v, vel_w = self.get_velocity()
        u_dot = cosb * cosa * xdot[:, 6] - vt * sinb * cosa * xdot[:, 8] - vt * cosb * sina * xdot[:, 7]
        v_dot = sinb * xdot[:, 6] + vt * cosb * xdot[:, 8]
        w_dot = cosb * sina * xdot[:, 6] - vt * sinb * sina * xdot[:, 8] + vt * cosb * cosa * xdot[:, 7]
        ax = u_dot + Q * vel_w - R * vel_v
        ay = v_dot + R * vel_u - P * vel_w
        az = w_dot + P * vel_v - Q * vel_u
        return ax, ay, az

    def get_G(self):
        """Return the overload: magnitude of the (nx, ny, nz) load-factor vector."""
        nx_cg, ny_cg, nz_cg = self.get_accels()
        G = torch.sqrt(nx_cg ** 2 + ny_cg ** 2 + nz_cg ** 2)
        return G

    def get_EAS2TAS(self):
        """Return the EAS-to-TAS ratio sqrt(1 / tfac**4.14) from altitude.

        This equals sqrt(rho0 / rho) under the density fit used in ``atmos``.
        """
        alt = self.s[:, 2]
        tfac = 1 - .703e-5 * (alt)
        eas2tas = 1 / torch.pow(tfac, 4.14)
        eas2tas = torch.sqrt(eas2tas)
        return eas2tas

    def get_accels(self):
        """Return the three-axis load factors (nx_cg, ny_cg, nz_cg) at the CG.

        Each is the body-axis acceleration divided by grav = 32.174, combined with
        the gravity components from pitch and roll.
        """
        grav = 32.174
        ax, ay, az = self.get_acceleration()
        theta, phi = self.s[:, 4], self.s[:, 3]
        nx_cg = 1.0 / grav * ax + torch.sin(theta)
        ny_cg = 1.0 / grav * ay - torch.cos(theta) * torch.sin(phi)
        nz_cg = -1.0 / grav * az + torch.cos(theta) * torch.cos(phi)
        return nx_cg, ny_cg, nz_cg

    def get_atmos(self):
        """Return (mach, qbar, ps) for the current altitude and airspeed.

        Delegates to ``F16Dynamics.atmos`` so both use a single atmosphere formula.
        """
        return self.dynamics.atmos(self.s[:, 2], self.s[:, 6])

### 5. Test your model

In [3]:
from tqdm import tqdm


# Smoke test: step N_ENVS aircraft for 1000 steps with a zero action on the CPU.
N_ENVS = 10000
cpu = torch.device('cpu')
model = F16Model(config=None, n=N_ENVS, device=cpu, random_seed=42)
# F16Model.update clamps via torch.clamp (out of place) and never mutates the
# action, so one zero tensor can be allocated once and reused every step.
action = torch.zeros((N_ENVS, 5), device=cpu)
for _ in tqdm(range(1000)):
    model.update(action)

100%|██████████| 1000/1000 [00:44<00:00, 22.53it/s]
