In [None]:
cd ../../

In [None]:
from src.data import get_data_loaders
from src.models.resnet import get_resnet
from src.pruning.slth.edgepopup import modify_module_for_slth
import copy
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR

# settings
num_epochs = 100
learning_rate = 0.1
weight_decay = 0.0001
seeds = 5  # NOTE(review): not referenced in the visible cells - confirm it is still needed
momentum = 0.9
batch_size = 128

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing on the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

resnet = get_resnet("ResNet18", 10).to(device)
resnet_slth = modify_module_for_slth(
    resnet, remain_rate=0.3, is_print=False
).to(device)
# Independent deep copy of the freshly converted model, taken before any training step.
resnet_slth_init = copy.deepcopy(resnet_slth).to(device)

# Use the `batch_size` setting rather than repeating the literal 128.
train_loader, test_loader = get_data_loaders(
    dataset_name="CIFAR10", batch_size=batch_size
)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    resnet_slth.parameters(),
    lr=learning_rate,
    momentum=momentum,
    weight_decay=weight_decay,
)

# Learning-rate scheduler (cosine annealing over `num_epochs` scheduler steps)
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)

# Disabled code: an SGD-only training/evaluation loop for `resnet_slth` (no DQN
# agent involved). It is wrapped in a bare string literal so it never executes.
# Because the literal is the last expression of this cell, the notebook echoes
# the string as the cell's output.
"""
# Train the model
losses = []
val_accuracies = []

total_step = len(train_loader)
for epoch in range(num_epochs):
    resnet_slth.train()
    epoch_losses = []
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = resnet_slth(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_losses.append(loss.item())

            

    # 学習率の更新
    scheduler.step()
    epoch_loss = sum(epoch_losses) / len(epoch_losses)
    losses.append(epoch_loss)

    
    # エポックごとにテストデータでモデルを評価
    resnet_slth.eval()  # 評価モードに設定
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = resnet_slth(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        acc = 100 * correct / total
        print(f"{epoch}: ACC {acc}")
        val_accuracies.append(acc)
"""

In [None]:
# Per-epoch history filled by the training loop: mean training loss and
# test accuracy (in percent). Re-running this cell resets both lists.
losses, val_accuracies = [], []

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class QNetwork(nn.Module):
    """Fully connected network with two ReLU hidden layers.

    Maps an input whose last dimension is ``state_dim`` to an output whose
    last dimension is ``action_dim``.

    Args:
        state_dim: Number of input features.
        action_dim: Number of output features.
        hidden_dim: Width of both hidden layers. Defaults to 256, the width
            that used to be hard-coded.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        """Return the raw (un-activated) output of the last linear layer for ``x``."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

class DQNAgent:
    """Agent made of an online Q-network, a target copy of it and an Adam optimizer.

    Args:
        state_dim: Length of the state vector fed to both networks.
        action_dim: Length of the action vector the networks output.
        lr: Learning rate of the Adam optimizer that trains the online network.
        device: Device holding both networks and every tensor built here.
        gamma: Discount factor applied to the bootstrapped next-state value.
        epsilon: Probability that ``select_action`` returns a random action.
    """

    def __init__(self, state_dim, action_dim, lr=0.001, device='cuda',
                 gamma=0.99, epsilon=0.1):
        self.device = device
        self.q_network = QNetwork(state_dim, action_dim).to(device)
        self.target_network = QNetwork(state_dim, action_dim).to(device)
        # Start the target network as an exact copy of the online network.
        self.target_network.load_state_dict(self.q_network.state_dict())
        # Referenced through `torch.optim` so the class does not depend on an
        # `optim` alias imported in a different notebook cell.
        self.optimizer = torch.optim.Adam(self.q_network.parameters(), lr=lr)
        self.gamma = gamma
        self.epsilon = epsilon
        self.action_dim = action_dim

    def _as_float_tensor(self, value):
        """Return ``value`` as a float32 tensor on ``self.device``.

        Accepts tensors on any device or of any dtype, NumPy arrays and plain
        Python sequences, which the legacy ``torch.FloatTensor`` constructor
        does not handle uniformly.
        """
        return torch.as_tensor(value, dtype=torch.float32, device=self.device)

    def select_action(self, state):
        """Return an action vector of length ``action_dim`` for ``state``.

        With probability ``epsilon`` the action is random; otherwise it is the
        tanh of the online network's output for ``state``.
        """
        if np.random.rand() < self.epsilon:
            # NOTE(review): the random action is an unbounded standard-normal
            # sample, while the greedy action below is squashed into [-1, 1]
            # by tanh - confirm the differing ranges are intended.
            return torch.randn(self.action_dim).to(self.device)
        state = self._as_float_tensor(state).unsqueeze(0)
        with torch.no_grad():
            q_values = self.q_network(state)
            return torch.tanh(q_values).squeeze(0)

    def update(self, state, action, reward, next_state, done):
        """Run one optimizer step on a single ``(s, a, r, s', done)`` transition.

        The prediction is the action-weighted sum of the online network's
        outputs; the target is ``reward + gamma * max(target_net(s')) * (1 - done)``.
        The two are compared with a mean-squared-error loss.
        """
        state = self._as_float_tensor(state).unsqueeze(0)
        next_state = self._as_float_tensor(next_state).unsqueeze(0)
        action = self._as_float_tensor(action)
        reward = self._as_float_tensor([reward])
        done = self._as_float_tensor([float(done)])

        # Weight each network output by the matching action component and sum.
        q_values = self.q_network(state)
        q_value = (q_values * action).sum(dim=1)

        # The target is a constant for the loss, so no autograd graph is built for it.
        with torch.no_grad():
            next_q_value = self.target_network(next_state).max(1)[0]
            expected_q_value = reward + self.gamma * next_q_value * (1 - done)

        loss = F.mse_loss(q_value, expected_q_value)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def update_target_network(self):
        """Overwrite the target network's weights with the online network's."""
        self.target_network.load_state_dict(self.q_network.state_dict())

 

def is_target_layer(name):
    """Return True when ``name`` contains one of the target-layer substrings."""
    for marker in ('fc.scores',):
        if marker in name:
            return True
    return False

def get_all_scores(model):
    """Concatenate every target-layer score parameter into one flat CPU tensor.

    Parameters are visited in ``model.named_parameters()`` order; each one is
    flattened and moved to the CPU before concatenation.
    """
    flattened = [
        tensor.data.view(-1).cpu()
        for param_name, tensor in model.named_parameters()
        if 'scores' in param_name and is_target_layer(param_name)
    ]
    return torch.cat(flattened)

def update_scores(model, score_changes, learning_rate=0.01):
    """Add ``learning_rate * score_changes`` in place to the target-layer scores.

    ``score_changes`` is a flat tensor laid out in ``model.named_parameters()``
    order; consecutive slices of it are reshaped to each target parameter.
    """
    offset = 0
    for param_name, tensor in model.named_parameters():
        if not ('scores' in param_name and is_target_layer(param_name)):
            continue
        count = tensor.numel()
        # Take this parameter's slice, move it to the parameter's device and reshape it.
        delta = score_changes[offset:offset + count].to(tensor.device).view(tensor.shape)
        tensor.data += learning_rate * delta
        offset += count

def get_target_params(model):
    """Return the total number of elements across all target-layer score parameters."""
    element_count = 0
    for param_name, tensor in model.named_parameters():
        if 'scores' in param_name and is_target_layer(param_name):
            element_count += tensor.numel()
    return element_count

def freeze_non_target_scores(model):
    """Set ``requires_grad = False`` on every score parameter outside the target layers."""
    for param_name, tensor in model.named_parameters():
        # Skip anything that is not a score tensor, and skip the target scores themselves.
        if 'scores' not in param_name or is_target_layer(param_name):
            continue
        tensor.requires_grad = False

# Disable gradients for all score tensors except the target ones.
freeze_non_target_scores(resnet_slth)

# The agent observes the flattened target scores and emits one change per
# score element, so the state and action sizes are the same number.
total_params = get_target_params(resnet_slth)
dqn_agent = DQNAgent(total_params, total_params, device=device)

def compute_accuracy(outputs, labels):
    """Return the fraction of rows of ``outputs`` whose max index along dim 1 equals the label."""
    predicted = outputs.data.max(dim=1).indices
    hits = (predicted == labels).sum().item()
    return hits / labels.size(0)



# Joint training loop. For every batch the DQN agent proposes a change to the
# target-layer scores, the model takes one optimizer step on the batch, and the
# batch accuracy is fed back to the agent as the reward. After each epoch the
# learning-rate schedule advances and the model is evaluated on the test set.
for epoch in range(num_epochs):
    resnet_slth.train()
    epoch_losses = []
    
    # The agent's state is the flattened vector of target-layer scores.
    state = get_all_scores(resnet_slth)
    
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Select action using DQN agent
        score_changes = dqn_agent.select_action(state)

        # Update scores in the model (in place, scaled by update_scores' default learning_rate)
        update_scores(resnet_slth, score_changes)
        
        # Forward pass
        outputs = resnet_slth(images)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        epoch_losses.append(loss.item())
        # Accuracy on this training batch; used below as the agent's reward.
        accuracy = compute_accuracy(outputs, labels)
        
        # Get new state (scores after both the agent's change and the optimizer step)
        next_state = get_all_scores(resnet_slth)
        
        # Update DQN agent; `done` is always False, so no transition is marked terminal
        dqn_agent.update(state, score_changes, accuracy, next_state, False)
        
        state = next_state

        if (i+1) % 100 == 0:  # Update target network periodically
            dqn_agent.update_target_network()

    # Advance the learning-rate schedule (once per epoch)
    scheduler.step()
    epoch_loss = sum(epoch_losses) / len(epoch_losses)
    losses.append(epoch_loss)

    # Evaluate the model on the test data after every epoch
    resnet_slth.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = resnet_slth(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        # Test accuracy in percent
        acc = 100 * correct / total
        print(f"Epoch {epoch}: Loss {epoch_loss:.4f}, ACC {acc:.2f}%")
        val_accuracies.append(acc)