# Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from deeplearning.train_val_test_loop import train_val_loop, test_loop
from deeplearning.callback.callback_list import CallbackList
from deeplearning.callback._logging_callback import _LoggingCallback
from deeplearning.callback.progressbar_callback import ProgressBarCallback

# Model 

In [None]:
class SimpleModel(nn.Module):
    """Minimal regression model: one fully connected layer mapping 10 features to 1 output."""

    def __init__(self):
        super(SimpleModel, self).__init__()
        # Attribute name `fc` is kept as-is: it determines the state_dict keys
        # ("fc.weight", "fc.bias").
        self.fc = nn.Linear(in_features=10, out_features=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``x`` (last dimension must be 10)."""
        prediction = self.fc(x)
        return prediction

# Data

## Train-Data

In [None]:
# Synthetic training data: 1M random feature rows (10 columns) and 1M random
# targets. Features are drawn before targets, as in the original cell.
x_train, y_train = torch.randn(1_000_000, 10), torch.randn(1_000_000, 1)

In [None]:
# Pair features with targets and serve them in shuffled mini-batches of 100.
train_dataset = TensorDataset(x_train, y_train)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=100,
    shuffle=True,
)

## Validation-Data

In [None]:
# Synthetic validation data, generated the same way as the training data.
x_val, y_val = torch.randn(1_000_000, 10), torch.randn(1_000_000, 1)

In [None]:
# Validation batches are served in a fixed order (DataLoader's default).
val_dataset = TensorDataset(x_val, y_val)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=100,
    shuffle=False,
)

## Test-Data

In [None]:
# Synthetic test data, generated the same way as the training data.
x_test, y_test = torch.randn(1_000_000, 10), torch.randn(1_000_000, 1)

In [None]:
# Test batches are served in a fixed order (DataLoader's default).
test_dataset = TensorDataset(x_test, y_test)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=100,
    shuffle=False,
)

# Criterion, Optimizer, Device, etc.

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move the model to the device *before* constructing the optimizer, as the
# torch.optim documentation recommends.
model = SimpleModel().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Same value as the previous `100_00`, now written with conventional
# thousands grouping.
# NOTE(review): with 1M training samples and batch_size=100 this is 10_000
# batches per epoch, i.e. 1e8 optimizer steps in total -- confirm this is
# intended and not a typo for 100.
num_epochs: int = 10_000
# Mixed-precision toggle for the training run; keep the train_val_loop call
# below in sync with it.
use_mixed_precision: bool = False
callbacks = CallbackList(
    [
        # _LoggingCallback(),  # uncomment to also register the logging callback
        ProgressBarCallback(),
    ]
)

# Train/Val-Loop

In [None]:
# Run training for `num_epochs` epochs, passing `val_loader` as the
# validation dataloader and the configured callbacks.
train_val_loop(
    train_dataloader=train_loader,
    validation_dataloader=val_loader,
    device=device,
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    callbacks=callbacks,
    num_epochs=num_epochs,
    # Forward the configured flag instead of a hard-coded False, so changing
    # `use_mixed_precision` in the configuration cell actually takes effect.
    use_mixed_precision=use_mixed_precision,
)

# Test-Loop

In [None]:
# Evaluate the trained model on the test loader, reusing the same callback
# list that was used during training.
test_loop(
    model=model,
    testloader=test_loader,
    callbacks=callbacks,
    device=device,
)