In [None]:
from model.ResNet34_attention import ResidualAttentionModel_92_32input_update
import torch
from util.util import load_data_percentage
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn 
import torch

In [None]:
# Use GPU index 5 whenever CUDA reports as available; otherwise stay on the CPU.
device = (
    torch.device("cuda:5")
    if torch.cuda.is_available()
    else torch.device("cpu")
)
print(f"Using device: {device}")

# Load the three arrays through the project helper.
# Shapes as recorded by the author for the full dataset:
#   X: (18480, 32768, 2)
#   Y: (18480, 32768, 1)
#   Z: (18480, 1, 2)
# percentage=100 presumably keeps every sample -- confirm in util.util.
X, Y, Z = load_data_percentage(
    './data/X.npy',
    './data/Y.npy',
    './data/Z.npy',
    percentage=100,
)

class CustomDataset(Dataset):
    """In-memory dataset that serves ``(x, y, z)`` float32 tensor triples.

    All three inputs are converted to ``torch.float32`` tensors once, at
    construction time, and are indexed along their first axis.

    Args:
        X: 3-D array-like laid out as ``(N, L, C)``. It is rearranged to
            ``(N, C, L)`` and then reshaped to ``(N, C, L // 256, L // 128)``.
            The reshape only succeeds when ``(L // 256) * (L // 128) == L``,
            e.g. ``L == 32768`` gives a ``128 x 256`` grid.
        Y: array-like stored unchanged apart from the dtype conversion.
        Z: array-like whose second-to-last axis is squeezed away when it has
            size 1 (e.g. ``(N, 1, 2)`` becomes ``(N, 2)``).
    """

    def __init__(self, X, Y, Z):
        # (N, L, C) -> (N, C, L): put the channel axis ahead of the sample axis.
        channels_first = torch.tensor(X, dtype=torch.float32).permute(0, 2, 1)
        n_items, n_channels, n_samples = channels_first.shape

        # Fold the sample axis into a 2-D grid per channel.
        grid_shape = (n_items, n_channels, n_samples // 256, n_samples // 128)
        self.X = channels_first.reshape(grid_shape)

        self.Y = torch.tensor(Y, dtype=torch.float32)
        self.Z = torch.tensor(Z, dtype=torch.float32).squeeze(-2)

    def __len__(self):
        """Return the number of samples (size of the first axis of ``X``)."""
        return self.X.shape[0]

    def __getitem__(self, idx):
        """Return the ``(x, y, z)`` triple stored at position ``idx``."""
        return tuple(store[idx] for store in (self.X, self.Y, self.Z))

# Wrap the loaded arrays in the Dataset so a DataLoader can batch them.
dataset = CustomDataset(X=X, Y=Y, Z=Z)

In [None]:
# Train the residual-attention model to regress Z from X, log the mean loss of
# every epoch, run a one-batch sanity evaluation, and save the final weights.
import os  # local to this cell: only used to create the output directories

# Set device
# NOTE(review): GPU index 5 is hard-coded; on a host with fewer than six GPUs
# the later `.to(device)` calls are expected to fail -- confirm the target machine.
device = torch.device("cuda:5" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Output locations. Create both directories up front so that a missing folder
# cannot abort the run at open() or, worse, at torch.save() after training.
log_path = "log/loss_resnet34+attention.txt"
weights_dir = "./weights"
os.makedirs(os.path.dirname(log_path), exist_ok=True)
os.makedirs(weights_dir, exist_ok=True)

# Create DataLoader
batch_size = 32
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
num_batches = len(train_loader)

# Initialize model, loss, and optimizer
model = ResidualAttentionModel_92_32input_update().to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# Training loop
num_epochs = 100
with open(log_path, "a") as log_file:  # append so earlier runs are kept
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        # The Y component of each sample is not used for training.
        for batch_idx, (x, _, z) in enumerate(train_loader):
            x, z = x.to(device), z.to(device)
            optimizer.zero_grad()
            outputs = model(x)
            loss = criterion(outputs, z)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss

            if (batch_idx + 1) % 10 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{num_batches}], Loss: {batch_loss:.4f}')

        avg_loss = total_loss / num_batches
        print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')

        # Log the average loss of each epoch to the file and flush right away
        # so the log can be followed while a long run is still in progress.
        log_file.write(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}\n')
        log_file.flush()

print("Training finished!")

# Simple evaluation: a single (shuffled) batch drawn from the training loader,
# so this is a sanity check rather than a held-out validation score.
model.eval()
with torch.no_grad():
    x, _, z = next(iter(train_loader))
    x, z = x.to(device), z.to(device)
    outputs = model(x)
    eval_loss = criterion(outputs, z)
    print(f"Evaluation Loss: {eval_loss.item():.4f}")

# Save model weights. The suffix is the zero-based index of the last epoch
# (same file name as before), computed from num_epochs instead of relying on
# the loop variable still being defined after the loop.
model_name = f"{weights_dir}/resnet34+attention_{num_epochs - 1}.pth"
torch.save(model.state_dict(), model_name)