In [None]:
import torch
import torch.nn as nn
import time
from tqdm import tqdm
!pip install livelossplot
from livelossplot import PlotLosses
# Select the compute device once at import time: the GPU when CUDA is
# available, otherwise the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

def _resolve_output_path(path, default_name):
    """Return *path*, or *path*/*default_name* when *path* is an existing directory."""
    import os  # local import: the visible top-of-file imports do not include os

    if os.path.isdir(path):
        return os.path.join(path, default_name)
    return path


def train(model, train_loader, val_loader, epochs=10, model_dir='.', timing_dir='.'):
    """Train *model* with an L1 loss and Adam, plotting losses after every epoch.

    Args:
        model: The ``nn.Module`` to optimise. It is moved to the module-level
            ``device`` and is left in eval mode when the function returns.
        train_loader: Iterable yielding ``(input, target)`` batch pairs used
            for the optimisation steps.
        val_loader: Iterable yielding ``(input, target)`` batch pairs used
            only for evaluation (no gradients are computed).
        epochs: Number of passes over ``train_loader``.
        model_dir: Where the whole model object is saved with ``torch.save``.
            A file path is used as-is; an existing directory gets the file
            ``model.pt`` created inside it.
        timing_dir: Where the elapsed training time is written as text.
            A file path is used as-is; an existing directory gets the file
            ``training_time.txt`` created inside it.

    Returns:
        The trained model (the same object that was passed in).
    """
    start_time = time.time()  # wall-clock start, reported at the end

    # Batches are moved to `device` below, so the model has to live there too.
    # Do this before building the optimizer so it tracks the final parameters.
    model = model.to(device)
    optim = torch.optim.Adam(model.parameters(), lr=0.001)
    liveloss = PlotLosses()  # tracks training/validation loss across epochs
    l1_loss = nn.L1Loss()
    # An alternative objective, alpha * dice_loss + (1 - alpha) * RE_loss,
    # was left commented out in the original code and is not enabled here.

    for _ in range(epochs):
        train_loss = 0.0
        train_batches = 0
        val_loss = 0.0
        val_batches = 0

        # Training pass: enable training-time behaviour (dropout, batch-norm).
        model.train()
        for batch_input, batch_target in tqdm(train_loader):
            batch_input = batch_input.to(device)
            batch_target = batch_target.to(device)
            optim.zero_grad()  # clear gradients from the previous step
            batch_output = model(batch_input)
            loss = l1_loss(batch_output, batch_target)
            loss.backward()
            optim.step()
            train_loss += loss.item()
            train_batches += 1

        # Validation pass: inference-time behaviour and no gradient tracking.
        model.eval()
        with torch.no_grad():
            for batch_input, batch_target in tqdm(val_loader):
                batch_input = batch_input.to(device)
                batch_target = batch_target.to(device)
                batch_output = model(batch_input)
                loss = l1_loss(batch_output, batch_target)
                val_loss += loss.item()
                val_batches += 1

        # Mean loss per batch. Counting batches in the loop (instead of calling
        # len() on the loader) also works for loaders without a length, and
        # max(..., 1) avoids a ZeroDivisionError when a loader is empty.
        logs = {
            'log loss': train_loss / max(train_batches, 1),
            'val_log loss': val_loss / max(val_batches, 1),
        }
        liveloss.update(logs)
        liveloss.send()

    # Report and persist the total training time.
    elapsed_time = time.time() - start_time
    print(f'Training time: {elapsed_time} seconds')
    timing_path = _resolve_output_path(timing_dir, 'training_time.txt')
    with open(timing_path, 'w', encoding='utf-8') as file:
        file.write(f'Training time: {elapsed_time} seconds')

    torch.save(model, _resolve_output_path(model_dir, 'model.pt'))
    return model