In [38]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader,Dataset
from torchtnt.utils.tqdm import create_progress_bar
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error, r2_score
from tqdm import tqdm

In [2]:
torch.manual_seed(67)

<torch._C.Generator at 0x2d6953f86d0>

In [3]:
check_gpu = torch.cuda.is_available()
device = torch.device("cpu")

if check_gpu:
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    
print(f'Using device: {device}')

Using device: cuda


In [4]:
df = pd.read_csv('../data/wave_packet_spread.csv')

In [5]:
df.head()

Unnamed: 0,h_bar,mass,time,sig_0,sig_0_2,sig_t
0,1,1,4.756226,9.381908,88.020205,9.385332
1,1,1,0.680782,0.188217,0.035426,1.818267
2,1,1,5.368285,5.265139,27.72169,5.289762
3,1,1,1.418382,3.975479,15.804436,3.97948
4,1,1,7.132206,9.257078,85.693495,9.26509


In [6]:
df.shape

(50000, 6)

##### <b>Wavepacket Spreading</b>

$$
\sigma(t) = \sigma_0 \, \sqrt{1 + \left(\frac{\hbar t}{2 m \sigma_0^2}\right)^2}
$$

The characteristic spreading time is defined as:

$$
t_c = \frac{2 m \sigma_0^2}{\hbar}
$$


##### Time-Dependent Wavepacket Width


Using the normalized time $\tfrac{t}{t_c}$, the wavepacket width becomes:

$$
\sigma(t) = \sigma_0 \, \sqrt{1 + \left(\frac{t}{t_c}\right)^2 }
$$

In [7]:
df['t_c'] =  (2 * df['mass'] * df['sig_0_2']) / df['h_bar'] # i know h bar's one gng 😭😂✌️

In [8]:
df['norm_time'] = df['time'] / df['t_c']

In [9]:
df['spreading_factor'] = 1 + np.square(df['norm_time'])

In [10]:
df

Unnamed: 0,h_bar,mass,time,sig_0,sig_0_2,sig_t,t_c,norm_time,spreading_factor
0,1,1,4.756226,9.381908,88.020205,9.385332,176.040411,0.027018,1.000730
1,1,1,0.680782,0.188217,0.035426,1.818267,0.070852,9.608565,93.324518
2,1,1,5.368285,5.265139,27.721690,5.289762,55.443380,0.096825,1.009375
3,1,1,1.418382,3.975479,15.804436,3.979480,31.608872,0.044873,1.002014
4,1,1,7.132206,9.257078,85.693495,9.265090,171.386989,0.041615,1.001732
...,...,...,...,...,...,...,...,...,...
49995,1,1,3.329266,1.077968,1.162014,1.883260,2.324029,1.432541,3.052172
49996,1,1,1.558871,5.098636,25.996091,5.100927,51.992182,0.029983,1.000899
49997,1,1,5.244002,5.927721,35.137876,5.944201,70.275751,0.074620,1.005568
49998,1,1,6.712440,3.603203,12.983073,3.721650,25.966145,0.258507,1.066826


In [11]:
features = df.columns.tolist()
features

['h_bar',
 'mass',
 'time',
 'sig_0',
 'sig_0_2',
 'sig_t',
 't_c',
 'norm_time',
 'spreading_factor']

In [12]:
y_feat = features.pop(features.index('sig_t'))
x_feat = features

In [13]:
print(f'X_features: {x_feat}, y_features: { y_feat}')

X_features: ['h_bar', 'mass', 'time', 'sig_0', 'sig_0_2', 't_c', 'norm_time', 'spreading_factor'], y_features: sig_t


In [14]:
X = df[x_feat].values
y = df[y_feat].values

In [15]:
X

array([[1.00000000e+00, 1.00000000e+00, 4.75622563e+00, ...,
        1.76040411e+02, 2.70178058e-02, 1.00072996e+00],
       [1.00000000e+00, 1.00000000e+00, 6.80781711e-01, ...,
        7.08515501e-02, 9.60856482e+00, 9.33245179e+01],
       [1.00000000e+00, 1.00000000e+00, 5.36828490e+00, ...,
        5.54433802e+01, 9.68246322e-02, 1.00937501e+00],
       ...,
       [1.00000000e+00, 1.00000000e+00, 5.24400229e+00, ...,
        7.02757515e+01, 7.46203658e-02, 1.00556820e+00],
       [1.00000000e+00, 1.00000000e+00, 6.71243982e+00, ...,
        2.59661454e+01, 2.58507365e-01, 1.06682606e+00],
       [1.00000000e+00, 1.00000000e+00, 7.90266995e+00, ...,
        1.98226178e+02, 3.98669340e-02, 1.00158937e+00]])

In [16]:
y

array([9.38533201, 1.81826659, 5.28976192, ..., 5.94420141, 3.72165021,
       9.96346408])

In [17]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=67)

In [18]:
X_train

array([[1.00000000e+00, 1.00000000e+00, 2.09499558e+00, ...,
        8.26443305e+01, 2.53495377e-02, 1.00064260e+00],
       [1.00000000e+00, 1.00000000e+00, 6.41435828e+00, ...,
        3.07784969e+00, 2.08403884e+00, 5.34321787e+00],
       [1.00000000e+00, 1.00000000e+00, 9.36743382e+00, ...,
        2.80982202e-01, 3.33381750e+01, 1.11243391e+03],
       ...,
       [1.00000000e+00, 1.00000000e+00, 8.39385446e+00, ...,
        4.46452458e+01, 1.88012280e-01, 1.03534862e+00],
       [1.00000000e+00, 1.00000000e+00, 6.35880252e+00, ...,
        1.31900861e+02, 4.82089537e-02, 1.00232410e+00],
       [1.00000000e+00, 1.00000000e+00, 5.89285185e-01, ...,
        6.32138596e+01, 9.32208836e-03, 1.00008690e+00]])

In [19]:
y_train

array([ 6.43029695,  2.86754437, 12.50148252, ...,  4.80746261,
        8.13041858,  5.62224835])

In [20]:
X_train, X_dev, y_train, y_dev = train_test_split(X_train, y_train, test_size=0.25, random_state=67)

In [21]:
X_train.shape

(30000, 8)

In [22]:
X_dev.shape

(10000, 8)

In [23]:
X_scaler = StandardScaler()

X_train = X_scaler.fit_transform(X_train)
X_dev = X_scaler.transform(X_dev)
X_test = X_scaler.transform(X_test)

In [24]:
y_train = y_train.reshape(-1, 1)  
y_dev   = y_dev.reshape(-1, 1)    
y_test  = y_test.reshape(-1, 1)   

print(f'''
      
      y_train shape: {y_train.shape}
      y_dev shape: {y_dev.shape}
      y_test shspe: {y_test.shape}
      
      ''')



      y_train shape: (30000, 1)
      y_dev shape: (10000, 1)
      y_test shspe: (10000, 1)

      


In [25]:
y_scaler = StandardScaler()

y_train = y_scaler.fit_transform(y_train)
y_dev = y_scaler.transform(y_dev)
y_test = y_scaler.transform(y_test)

In [26]:
class QWaveSet(Dataset):
    
    def __init__(self, features, labels):
        
        self.features = torch.tensor(features, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)
        
    def __len__(self):
        return len(self.features)
    
    def __getitem__(self, index):
        return self.features[index], self.labels[index]

In [27]:
data_train = QWaveSet(X_train, y_train)
data_dev = QWaveSet(X_dev, y_dev)
data_test = QWaveSet(X_test, y_test)

In [28]:
train_loader = DataLoader(data_train, batch_size=32, shuffle=True)
dev_loader = DataLoader(data_dev, batch_size=32, shuffle=True)
test_loader = DataLoader(data_test, batch_size=32, shuffle=False)

In [29]:
f'train loader size : {len(train_loader)}' , f'dev loader size: {len(dev_loader)}'

('train loader size : 938', 'dev loader size: 313')

In [30]:
class QWaveModel(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.model = nn.Sequential(
            
            nn.Linear(input_size, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            
            nn.Linear(64, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            
            nn.Linear(64,1)
            
        )
        
    def forward(self, X):
        return self.model(X)

In [31]:
learning_rate = 1e-3
weight_decay = 1e-4
batch_size = 32
epochs = 250

In [32]:
model = QWaveModel(X_train.shape[1])

model.to(device)

loss_fn = nn.MSELoss()

optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

In [33]:
num_steps_per_epoch = len(train_loader)

for epoch in range(epochs):

    pbar = create_progress_bar(
        dataloader=train_loader,
        desc=f"Epoch [{epoch+1}/{epochs}]",
        num_epochs_completed=epoch,
        num_steps_completed=0,
        max_steps=None,
        max_steps_per_epoch=num_steps_per_epoch,
    )

    for batch_idx, (batch_features, y_true) in enumerate(train_loader):
        batch_features, y_true = batch_features.to(device), y_true.to(device)

        y_pred = model(batch_features)
        loss = loss_fn(y_pred, y_true)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.update(1)
        pbar.set_postfix(loss=loss.item())

    pbar.close() 


Epoch [1/250] 0:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [2/250] 1:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [3/250] 2:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [4/250] 3:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [5/250] 4:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [6/250] 5:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [7/250] 6:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [8/250] 7:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [9/250] 8:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [10/250] 9:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [11/250] 10:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [12/250] 11:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [13/250] 12:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [14/250] 13:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [15/250] 14:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [16/250] 15:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [17/250] 16:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [18/250] 17:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [19/250] 18:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [20/250] 19:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [21/250] 20:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [22/250] 21:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [23/250] 22:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [24/250] 23:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [25/250] 24:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [26/250] 25:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [27/250] 26:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [28/250] 27:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [29/250] 28:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [30/250] 29:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [31/250] 30:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [32/250] 31:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [33/250] 32:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [34/250] 33:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [35/250] 34:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [36/250] 35:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [37/250] 36:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [38/250] 37:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [39/250] 38:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [40/250] 39:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [41/250] 40:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [42/250] 41:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [43/250] 42:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [44/250] 43:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [45/250] 44:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [46/250] 45:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [47/250] 46:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [48/250] 47:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [49/250] 48:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [50/250] 49:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [51/250] 50:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [52/250] 51:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [53/250] 52:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [54/250] 53:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [55/250] 54:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [56/250] 55:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [57/250] 56:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [58/250] 57:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [59/250] 58:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [60/250] 59:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [61/250] 60:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [62/250] 61:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [63/250] 62:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [64/250] 63:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [65/250] 64:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [66/250] 65:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [67/250] 66:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [68/250] 67:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [69/250] 68:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [70/250] 69:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [71/250] 70:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [72/250] 71:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [73/250] 72:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [74/250] 73:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [75/250] 74:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [76/250] 75:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [77/250] 76:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [78/250] 77:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [79/250] 78:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [80/250] 79:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [81/250] 80:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [82/250] 81:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [83/250] 82:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [84/250] 83:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [85/250] 84:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [86/250] 85:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [87/250] 86:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [88/250] 87:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [89/250] 88:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [90/250] 89:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [91/250] 90:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [92/250] 91:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [93/250] 92:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [94/250] 93:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [95/250] 94:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [96/250] 95:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [97/250] 96:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [98/250] 97:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [99/250] 98:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [100/250] 99:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [101/250] 100:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [102/250] 101:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [103/250] 102:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [104/250] 103:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [105/250] 104:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [106/250] 105:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [107/250] 106:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [108/250] 107:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [109/250] 108:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [110/250] 109:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [111/250] 110:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [112/250] 111:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [113/250] 112:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [114/250] 113:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [115/250] 114:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [116/250] 115:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [117/250] 116:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [118/250] 117:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [119/250] 118:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [120/250] 119:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [121/250] 120:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [122/250] 121:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [123/250] 122:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [124/250] 123:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [125/250] 124:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [126/250] 125:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [127/250] 126:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [128/250] 127:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [129/250] 128:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [130/250] 129:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [131/250] 130:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [132/250] 131:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [133/250] 132:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [134/250] 133:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [135/250] 134:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [136/250] 135:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [137/250] 136:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [138/250] 137:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [139/250] 138:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [140/250] 139:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [141/250] 140:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [142/250] 141:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [143/250] 142:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [144/250] 143:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [145/250] 144:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [146/250] 145:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [147/250] 146:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [148/250] 147:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [149/250] 148:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [150/250] 149:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [151/250] 150:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [152/250] 151:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [153/250] 152:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [154/250] 153:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [155/250] 154:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [156/250] 155:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [157/250] 156:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [158/250] 157:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [159/250] 158:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [160/250] 159:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [161/250] 160:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [162/250] 161:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [163/250] 162:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [164/250] 163:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [165/250] 164:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [166/250] 165:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [167/250] 166:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [168/250] 167:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [169/250] 168:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [170/250] 169:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [171/250] 170:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [172/250] 171:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [173/250] 172:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [174/250] 173:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [175/250] 174:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [176/250] 175:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [177/250] 176:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [178/250] 177:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [179/250] 178:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [180/250] 179:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [181/250] 180:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [182/250] 181:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [183/250] 182:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [184/250] 183:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [185/250] 184:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [186/250] 185:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [187/250] 186:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [188/250] 187:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [189/250] 188:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [190/250] 189:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [191/250] 190:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [192/250] 191:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [193/250] 192:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [194/250] 193:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [195/250] 194:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [196/250] 195:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [197/250] 196:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [198/250] 197:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [199/250] 198:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [200/250] 199:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [201/250] 200:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [202/250] 201:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [203/250] 202:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [204/250] 203:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [205/250] 204:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [206/250] 205:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [207/250] 206:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [208/250] 207:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [209/250] 208:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [210/250] 209:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [211/250] 210:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [212/250] 211:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [213/250] 212:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [214/250] 213:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [215/250] 214:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [216/250] 215:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [217/250] 216:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [218/250] 217:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [219/250] 218:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [220/250] 219:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [221/250] 220:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [222/250] 221:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [223/250] 222:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [224/250] 223:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [225/250] 224:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [226/250] 225:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [227/250] 226:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [228/250] 227:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [229/250] 228:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [230/250] 229:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [231/250] 230:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [232/250] 231:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [233/250] 232:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [234/250] 233:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [235/250] 234:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [236/250] 235:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [237/250] 236:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [238/250] 237:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [239/250] 238:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [240/250] 239:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [241/250] 240:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [242/250] 241:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [243/250] 242:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [244/250] 243:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [245/250] 244:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [246/250] 245:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [247/250] 246:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [248/250] 247:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [249/250] 248:   0%|          | 0/938 [00:00<?, ?it/s]


Epoch [250/250] 249:   0%|          | 0/938 [00:00<?, ?it/s]


In [34]:
model.eval

<bound method Module.eval of QWaveModel(
  (model): Sequential(
    (0): Linear(in_features=8, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=128, out_features=64, bias=True)
    (4): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Linear(in_features=64, out_features=64, bias=True)
    (7): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): Linear(in_features=64, out_features=1, bias=True)
  )
)>

In [39]:
with torch.no_grad():
    y_pred = []
    y_true_list = []
    
    for batch_features, y_true in test_loader:
        batch_features, y_true = batch_features.to(device), y_true.to(device)
        outputs = model(batch_features)
        
        y_pred.append(outputs.cpu())
        y_true_list.append(y_true.cpu())

    y_pred = torch.cat(y_pred).numpy().flatten()
    y_true = torch.cat(y_true_list).numpy().flatten()


mae = mean_absolute_error(y_true, y_pred)
r2 = r2_score(y_true, y_pred)

print(f"MAE: {mae:.4f}")
print(f"R² Score: {r2:.4f}")

MAE: 0.0248
R² Score: 0.9982


In [None]:
torch.save(model.state_dict(), "WavePacketProp_model.pth")