import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from cellmap_segmentation_challenge.utils import get_dataloader
from cellmap_segmentation_challenge.utils.loss import CellMapLossWrapper


class Network(nn.Module):
    """Minimal two-layer 3D convolutional network.

    Maps a single-channel volume through an 8-channel hidden layer and back
    to a single channel. Both convolutions use 3x3x3 kernels with stride 1
    and padding 1, so spatial dimensions are preserved.
    """

    def __init__(self):
        super().__init__()
        # 1 -> 8 -> 1 channels; shape-preserving (kernel 3, stride 1, pad 1).
        self.conv1 = nn.Conv3d(1, 8, 3, 1, 1)
        self.conv2 = nn.Conv3d(8, 1, 3, 1, 1)

    def forward(self, x):
        """Apply both convolutions; returns raw logits (no activation)."""
        return self.conv2(self.conv1(x))


def main():
    """Train the toy 3D segmentation network for a fixed number of epochs.

    Builds the model, loss, optimizer, and CellMap data loaders, then runs
    a train/validation loop with mixed precision (on CUDA) and logs losses
    and VRAM usage to TensorBoard. No return value; side effects are the
    TensorBoard event files written by ``SummaryWriter``.
    """
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Initialize model and loss.
    # NOTE(review): CellMapLossWrapper is given the loss *class*, not an
    # instance — presumably it instantiates it internally; verify against
    # the wrapper's API.
    model = Network().to(device)
    loss_fn = CellMapLossWrapper(nn.BCEWithLogitsLoss)

    # Initialize optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.001, fused=False)

    # Initialize TensorBoard writer
    writer = SummaryWriter()

    # Get data loaders. Input and target share the same shape/scale request.
    request_array_info = {
        "shape": (128, 128, 128),
        "scale": (8, 8, 8),
    }
    iterations_per_epoch = 200
    train_loader, val_loader = get_dataloader(
        datasplit_path='datasplit.csv',
        classes=["ecs"],
        batch_size=2,
        input_array_info=request_array_info,
        target_array_info=request_array_info,
        iterations_per_epoch=iterations_per_epoch,
        device='cpu',
        num_workers=4,
        pin_memory=True,
        persistent_workers=False,
    )

    # Training parameters
    max_epochs = 2
    global_step = 0

    # Mixed precision setup: fp16 autocast + gradient scaling on CUDA only;
    # on CPU the autocast context runs in full fp32 and no scaler is used.
    scaler = torch.amp.GradScaler() if device.type == 'cuda' else None
    dtype = torch.float16 if device.type == 'cuda' else torch.float32

    for epoch in range(max_epochs):
        # ---- Training phase ----
        model.train()
        train_loader.refresh()
        loader = iter(train_loader.loader)
        epoch_bar = tqdm(
            range(iterations_per_epoch), desc="Training", dynamic_ncols=True
        )
        for epoch_iter in epoch_bar:
            batch = next(loader)
            # Clear gradients once per iteration (no accumulation).
            optimizer.zero_grad()

            # Move data to device
            inputs = batch['input'].to(device, non_blocking=True)
            targets = batch['output'].to(device, non_blocking=True)

            # Forward pass with mixed precision
            with torch.autocast(device_type=device.type, dtype=dtype):
                outputs = model(inputs)
                loss = loss_fn(outputs, targets)

            # Backward pass (scaled on CUDA to avoid fp16 underflow).
            if scaler is not None:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()

            # Log training loss. BUGFIX: global_step was previously
            # incremented twice per iteration (before and after logging),
            # double-counting the TensorBoard x-axis.
            writer.add_scalar("Train/Loss", loss.detach(), global_step)
            global_step += 1

        # ---- Validation phase ----
        val_loader.refresh()
        model.eval()
        val_bar = tqdm(
            val_loader.loader,
            desc="Validation",
            total=len(val_loader.loader),
            dynamic_ncols=True,
        )
        val_metrics = []
        with torch.no_grad():
            for batch in val_bar:
                inputs = batch['input'].to(device, non_blocking=True)
                targets = batch['output'].to(device, non_blocking=True)

                with torch.autocast(device_type=device.type, dtype=dtype):
                    outputs = model(inputs)
                    loss = loss_fn(outputs, targets)

                val_metrics.append(loss.detach())

        # Log mean validation loss for the epoch.
        epoch_loss = torch.stack(val_metrics).mean()
        writer.add_scalar("Val/Loss", epoch_loss, epoch)

        # Log VRAM usage (CUDA only). NOTE(review): mem_get_info reports
        # device-wide (free, total) bytes, so this includes memory held by
        # other processes on the same GPU — it is not this process's usage.
        if device.type == 'cuda':
            vram_data = torch.cuda.mem_get_info()
            vram_usage = (vram_data[1] - vram_data[0]) / (1024 ** 2)
            writer.add_scalar("Other/VRAM Usage (MB)", vram_usage, epoch)
            # Also track this process's peak allocation, then reset the
            # counter so each epoch's peak is measured independently.
            # (Previously the stats were reset without ever being read.)
            peak_mb = torch.cuda.max_memory_allocated() / (1024 ** 2)
            writer.add_scalar("Other/Peak Allocated (MB)", peak_mb, epoch)
            torch.cuda.reset_peak_memory_stats()

    writer.close()


# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()