In [13]:
from model.ResNet_attention import ResidualAttentionModel_92_32input_update
import torch
from util.util import load_data_percentage, SpecDataset, create_data_loaders
import torch.nn as nn 
import torch
import datetime
import os

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [14]:
# Timestamp used to tag this run's log file and weight checkpoints.
now = datetime.datetime.now()
current_time = now.strftime("%Y_%m_%d_%H_%M_%S")

# Load the three toy arrays.
# NOTE(review): percentage=100 presumably means "use the whole dataset" -- confirm in util.util.
X, Y, Z = load_data_percentage('./data/X_toy.npy', './data/Y_toy.npy', './data/Z_toy.npy', percentage=100)

# Prefer GPU 5, but do not crash on machines that lack it: fall back to the
# first GPU, then to the CPU. The chosen device is used later by model.to(...)
# and by the per-batch tensor.to(...) calls.
preferred_gpu_index = 5
if torch.cuda.is_available() and torch.cuda.device_count() > preferred_gpu_index:
    cuda_device = torch.device(f"cuda:{preferred_gpu_index}")
elif torch.cuda.is_available():
    cuda_device = torch.device("cuda:0")
else:
    cuda_device = torch.device("cpu")
if cuda_device != torch.device(f"cuda:{preferred_gpu_index}"):
    print(f"cuda:{preferred_gpu_index} is not available; using {cuda_device} instead")

# Everything a reader might want to tune for a run lives in this dict.
config = {
    'model_type': 'resnet_attention',  # used in log / checkpoint file names
    'current_time': current_time,      # used in log / checkpoint file names
    'epoch': 1,                        # number of training epochs
    'batch_size': 32,                  # passed to create_data_loaders
    'cuda_device': cuda_device         # torch.device the model and batches are moved to
}

dataset = SpecDataset(X, Y, Z, config)

In [15]:
device = config['cuda_device']
num_epochs = config['epoch']


def _sum_batch_losses(model, loader, criterion, device):
    """Return the sum of the per-batch losses of `model` over `loader`.

    The model is switched to eval mode and gradient tracking is disabled.
    Each batch is unpacked as (x, _, z): the middle element is ignored and
    `criterion` is applied to model(x) and z. Divide the result by
    len(loader) to obtain the mean per-batch loss.
    """
    model.eval()
    total = 0
    with torch.no_grad():
        for x, _, z in loader:
            x, z = x.to(device), z.to(device)
            total += criterion(model(x), z).item()
    return total


model = ResidualAttentionModel_92_32input_update()
model.to(device)

best_val_loss = float('inf')
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

os.makedirs('./log', exist_ok=True)
os.makedirs('./weights', exist_ok=True)
criterion = nn.MSELoss()

train_loader, val_loader, test_loader = create_data_loaders(dataset, batch_size=config['batch_size'])

# The log file is opened in append mode and stays open for the whole run;
# one summary line is written (and flushed) per epoch.
with open(f"./log/{config['current_time']}_{config['model_type']}_{config['epoch']}.txt", "a") as log_file:
    for epoch in range(num_epochs):

        # ---- Train for one epoch ----
        model.train()
        train_loss = 0
        for batch_idx, (x, _, z) in enumerate(train_loader):
            x, z = x.to(device), z.to(device)
            optimizer.zero_grad()
            outputs = model(x)
            loss = criterion(outputs, z)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            # Progress line every 10 batches
            if (batch_idx + 1) % 10 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

        avg_train_loss = train_loss / len(train_loader)

        # ---- Validation and test losses (mean of per-batch losses) ----
        val_loss = _sum_batch_losses(model, val_loader, criterion, device)
        avg_val_loss = val_loss / len(val_loader)

        test_loss = _sum_batch_losses(model, test_loader, criterion, device)
        avg_test_loss = test_loss / len(test_loader)

        # The same summary goes to stdout and to the log file
        epoch_summary = f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Test Loss: {avg_test_loss:.4f}'
        print(epoch_summary)
        log_file.write(epoch_summary + '\n')
        log_file.flush()

        # Save the best model (lowest mean validation loss seen so far)
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_model_name = f"./weights/{config['current_time']}_{config['model_type']}_best.pth"
            torch.save(model.state_dict(), best_model_name)
            print(f"Saved best model with validation loss: {best_val_loss:.4f}")

print("Training finished!")

# Save the last model
last_model_name = f"./weights/{config['current_time']}_{config['model_type']}_last.pth"
torch.save(model.state_dict(), last_model_name)
print("Saved last model")

# Final evaluation of the last-epoch weights (not the best checkpoint) on the
# test set. The helper leaves the model in eval mode, as this cell always did.
final_test_loss = _sum_batch_losses(model, test_loader, criterion, device)
avg_final_test_loss = final_test_loss / len(test_loader)
print(f"Final Test Loss: {avg_final_test_loss:.4f}")

After conv1: torch.Size([7, 32, 128, 256])
After residual_block1: torch.Size([7, 128, 128, 256])
After reshape: torch.Size([7, 256, 128, 128])
After attention_module1: torch.Size([7, 256, 128, 128])
After residual_block2: torch.Size([7, 256, 16, 16])
After attention_module2: torch.Size([7, 256, 16, 16])
After attention_module2_2: torch.Size([7, 256, 16, 16])
After residual_block3: torch.Size([7, 512, 8, 8])
After attention_module3: torch.Size([7, 512, 8, 8])
After attention_module3_2: torch.Size([7, 512, 8, 8])
After attention_module3_3: torch.Size([7, 512, 8, 8])
After residual_block4: torch.Size([7, 1024, 8, 8])
After residual_block5: torch.Size([7, 1024, 8, 8])
After residual_block6: torch.Size([7, 1024, 8, 8])
After mpool2: torch.Size([7, 1024, 1, 1])
After flatten: torch.Size([7, 1024])
After fully connected layer (fc): torch.Size([7, 2])
After conv1: torch.Size([1, 32, 128, 256])
After residual_block1: torch.Size([1, 128, 128, 256])
After reshape: torch.Size([1, 256, 128, 128])
A

In [19]:
import matplotlib.pyplot as plt
import numpy as np
import os

# Most recent captured output per layer name, filled in by the hooks built below
feature_maps = dict()


def save_feature_map(name):
    """Build a forward hook that records a layer's output under `name`.

    The returned callable takes (module, inputs, output) positionally. Each
    time it is called it detaches `output`, moves it to the CPU, converts it
    to a NumPy array and stores it in `feature_maps[name]`, overwriting any
    previous entry for that name.
    """
    def _store_output(_module, _inputs, output):
        snapshot = output.detach().cpu().numpy()
        feature_maps[name] = snapshot

    return _store_output

# Register hooks on the desired layers.
# Re-running this cell would otherwise stack another hook on the same layers
# (each one copies the layer output to the CPU on every forward pass), so first
# remove whatever a previous run of this cell registered.
for _stale_handle in globals().get('feature_map_hook_handles', []):
    _stale_handle.remove()

# Keep the returned handles: calling .remove() on them detaches the hooks again.
feature_map_hook_handles = [
    model.conv1.register_forward_hook(save_feature_map('conv1')),
    model.residual_block1.register_forward_hook(save_feature_map('residual_block1')),
]

# Visualization function
def plot_feature_maps(feature_map, layer_name, epoch, batch_idx):
    """Save a grid image showing every channel of one sample's feature map.

    Args:
        feature_map: array indexed as [sample, channel, row, col]; only
            sample 0 is drawn, one subplot per channel.
        layer_name: used in the figure title and in the output directory.
        epoch: zero-based epoch index (shown and saved as epoch + 1).
        batch_idx: zero-based batch index (shown and saved as batch_idx + 1).

    The image is written to
    ./feature_maps/{layer_name}/epoch_{epoch+1}/batch_{batch_idx+1}.png.
    Nothing is drawn or written when the feature map has no channels.
    """
    num_channels = feature_map.shape[1]
    if num_channels == 0:
        # Nothing to plot; the grid computation below would divide by zero.
        return

    rows = int(np.sqrt(num_channels))  # Grid dimensions
    cols = int(np.ceil(num_channels / rows))

    # squeeze=False keeps `axes` two-dimensional even for a single row or a
    # single cell, so small channel counts (1-3) are handled like any other.
    fig, axes = plt.subplots(rows, cols, figsize=(15, 15), squeeze=False)
    try:
        fig.suptitle(f'{layer_name} Feature Maps - Epoch {epoch+1}, Batch {batch_idx+1}', fontsize=16)

        # axes.flat walks the grid row by row, so position i is channel i.
        # Cells beyond the last channel are left blank with their frame hidden.
        for i, ax in enumerate(axes.flat):
            if i < num_channels:
                ax.imshow(feature_map[0, i, :, :], cmap='viridis')
            ax.axis('off')

        fig.tight_layout()
        fig.subplots_adjust(top=0.9)  # leave room for the suptitle
        save_dir = f"./feature_maps/{layer_name}/epoch_{epoch+1}"
        os.makedirs(save_dir, exist_ok=True)
        fig.savefig(f"{save_dir}/batch_{batch_idx+1}.png")
    finally:
        # Always release the figure, even if drawing or saving failed, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)

# One extra training pass over train_loader that also saves feature-map images.
# NOTE: `epoch` and `num_epochs` are whatever the earlier training cell left
# behind; they are only used to label the printed lines and the saved images.

# The model may have been left in eval mode by an earlier evaluation step.
# Train in train mode, then put the model back the way it was found.
_was_training = model.training
model.train()
train_loss = 0  # start a fresh running total instead of adding to a stale one
try:
    for batch_idx, (x, _, z) in enumerate(train_loader):
        x, z = x.to(device), z.to(device)
        optimizer.zero_grad()
        outputs = model(x)
        loss = criterion(outputs, z)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        # Plot and save the feature maps captured by the forward hooks
        if batch_idx % 10 == 0:  # Save every 10 batches
            for layer_name, fmap in feature_maps.items():
                plot_feature_maps(fmap, layer_name, epoch, batch_idx)

        if (batch_idx + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item():.4f}')
finally:
    model.train(_was_training)


After conv1: torch.Size([7, 32, 128, 256])
After residual_block1: torch.Size([7, 128, 128, 256])
After reshape: torch.Size([7, 256, 128, 128])
After attention_module1: torch.Size([7, 256, 128, 128])
After residual_block2: torch.Size([7, 256, 16, 16])
After attention_module2: torch.Size([7, 256, 16, 16])
After attention_module2_2: torch.Size([7, 256, 16, 16])
After residual_block3: torch.Size([7, 512, 8, 8])
After attention_module3: torch.Size([7, 512, 8, 8])
After attention_module3_2: torch.Size([7, 512, 8, 8])
After attention_module3_3: torch.Size([7, 512, 8, 8])
After residual_block4: torch.Size([7, 1024, 8, 8])
After residual_block5: torch.Size([7, 1024, 8, 8])
After residual_block6: torch.Size([7, 1024, 8, 8])
After mpool2: torch.Size([7, 1024, 1, 1])
After flatten: torch.Size([7, 1024])
After fully connected layer (fc): torch.Size([7, 2])
