In [None]:
# Move to the repository root (two levels up) so the `src.*` imports below resolve.
# Guarded so that re-running this cell does not keep climbing the directory tree.
import os

if not os.path.isdir("src"):
    os.chdir("../../")

In [None]:
from src.data import get_data_loaders
from src.models.resnet import get_resnet
from src.pruning.slth.edgepopup import modify_module_for_slth
import copy
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR

# settings
num_epochs = 100
learning_rate = 0.1
weight_decay = 0.0001
seeds = 5  # NOTE(review): not used in the visible notebook, and no RNG seed is set anywhere.
momentum = 0.9
batch_size = 128

# Fall back to CPU so the notebook can still run on a machine without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
resnet = get_resnet("ResNet18", 10).to(device)
# Convert the ResNet for strong-lottery-ticket (edge-popup) training.
# NOTE(review): remain_rate=0.3 presumably is the fraction of weights kept — confirm in src.pruning.
resnet_slth = modify_module_for_slth(
    resnet, remain_rate=0.3, is_print=False
).to(device)
# Independent snapshot of the converted model, taken before any training happens.
resnet_slth_init = copy.deepcopy(resnet_slth).to(device)

# Use the configured batch size instead of a second hard-coded copy of it.
train_loader, test_loader = get_data_loaders(
    dataset_name="CIFAR10", batch_size=batch_size
)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    resnet_slth.parameters(),
    lr=learning_rate,
    momentum=momentum,
    weight_decay=weight_decay,
)

# Learning-rate scheduler (cosine annealing over num_epochs epochs).
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)

# NOTE(review): the triple-quoted string below is disabled code, not documentation. It is a
# plain (non-RL) training + per-epoch test-evaluation loop for resnet_slth, and Python
# evaluates it as a no-op string expression. Its inner comments are Japanese: "update the
# learning rate", "evaluate the model on test data every epoch", "set to evaluation mode".
# Consider deleting it, or moving it into a function if the baseline is still needed.
"""
# Train the model
losses = []
val_accuracies = []

total_step = len(train_loader)
for epoch in range(num_epochs):
    resnet_slth.train()
    epoch_losses = []
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Forward pass
        outputs = resnet_slth(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_losses.append(loss.item())

            

    # 学習率の更新
    scheduler.step()
    epoch_loss = sum(epoch_losses) / len(epoch_losses)
    losses.append(epoch_loss)

    
    # エポックごとにテストデータでモデルを評価
    resnet_slth.eval()  # 評価モードに設定
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = resnet_slth(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        acc = 100 * correct / total
        print(f"{epoch}: ACC {acc}")
        val_accuracies.append(acc)
"""

In [None]:
# Per-epoch history that the training loop appends to: mean training loss and test accuracy (%).
losses, val_accuracies = [], []

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class RLAgent(nn.Module):
    """Three-layer MLP: state_dim -> 256 -> 256 -> action_dim.

    The two hidden layers use ReLU. The output layer is linear (no activation), so the
    returned values are unbounded; callers squash them as needed.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        width = 256
        # Layers are created in fc1, fc2, fc3 order, which fixes parameter naming and init order.
        self.fc1 = nn.Linear(state_dim, width)
        self.fc2 = nn.Linear(width, width)
        self.fc3 = nn.Linear(width, action_dim)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

class PolicyGradientAgent:
    """Wraps an ``RLAgent`` network and an Adam optimizer to propose bounded score changes.

    NOTE(review): despite the name, this is not textbook REINFORCE.
    - ``select_action`` is deterministic: no sampling and no exploration noise.
    - ``update`` multiplies raw network outputs (not log-probabilities of a distribution)
      by the stored actions and normalized rewards.
    """

    def __init__(self, state_dim, action_dim, lr=0.001, device='cuda'):
        self.device = device
        self.policy = RLAgent(state_dim, action_dim).to(device)
        self.optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        # NOTE(review): discount factor is stored but never used by this class.
        self.gamma = 0.99

    def select_action(self, state):
        """Return ``tanh(policy(state))``, a tensor in [-1, 1] with no autograd history.

        Args:
            state: a tensor, or anything ``torch.as_tensor`` accepts (converted to float32).
                Tensors keep their dtype and are only moved to ``self.device``.
        """
        if not isinstance(state, torch.Tensor):
            state = torch.as_tensor(state, dtype=torch.float32)
        state = state.to(self.device)
        # The result is only applied to the model / stored as data, so skip building a graph.
        # (no_grad rather than inference_mode: the action is later used in update()'s backward.)
        with torch.no_grad():
            # tanh bounds each component to [-1, 1].
            return torch.tanh(self.policy(state))

    def update(self, states, actions, rewards):
        """Take one Adam step on the surrogate loss ``-mean(policy(s) * a * r_normalized)``.

        Args:
            states: list of state tensors (stacked along a new leading dim).
            actions: list of action tensors, one per state.
            rewards: list of scalar rewards, one per state; standardized before use.

        With fewer than two rewards this is a no-op. The unbiased ``std()`` of a single
        reward is NaN, which would turn every policy weight into NaN, and an empty list
        cannot be stacked.
        """
        if len(rewards) < 2:
            return

        states = torch.stack(states).to(self.device)
        actions = torch.stack(actions).to(self.device)
        rewards = torch.as_tensor(rewards, dtype=torch.float32, device=self.device).unsqueeze(1)

        # Standardize rewards; the epsilon guards against zero spread.
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

        # NOTE(review): these are raw policy outputs, not log-probabilities.
        outputs = self.policy(states)
        loss = -torch.mean(outputs * actions * rewards)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
 

def is_target_layer(name):
    """Return True when parameter ``name`` contains one of the RL-controlled layer keys.

    Matching is by substring, so ``'fc.scores'`` also matches e.g. ``'module.fc.scores'``.
    """
    for key in ('fc.scores',):
        if key in name:
            return True
    return False

def get_all_scores(model):
    """Return the RL-controlled score tensors of ``model`` flattened into one 1-D CPU tensor.

    Only parameters whose name contains ``'scores'`` and passes ``is_target_layer`` are used.
    They are visited in ``model.named_parameters()`` order, the same order and filter
    ``update_scores`` uses to slice its ``score_changes``.

    Returns:
        A new tensor (``torch.cat`` copies) that is detached from autograd. If no parameter
        matches, an empty tensor is returned instead of ``torch.cat`` raising on an empty list.
    """
    flat_scores = [
        # reshape (unlike view) also works for non-contiguous parameters.
        param.detach().reshape(-1).cpu()
        for name, param in model.named_parameters()
        if 'scores' in name and is_target_layer(name)
    ]
    if not flat_scores:
        return torch.empty(0)
    return torch.cat(flat_scores)

def update_scores(model, score_changes, learning_rate=0.01):
    """Add ``learning_rate * score_changes`` to the RL-controlled score tensors, in place.

    Args:
        model: module whose ``'scores'`` parameters passing ``is_target_layer`` get updated.
        score_changes: flat 1-D tensor in ``model.named_parameters()`` order, the layout
            ``get_all_scores`` produces. Consecutive slices are reshaped to each target
            parameter's shape; trailing extra entries are ignored.
        learning_rate: scale applied to the changes. This is separate from any optimizer's lr.
    """
    offset = 0
    # Modify leaf parameters under no_grad rather than via ``.data``: autograd then still
    # tracks the in-place change (version counter) instead of silently missing it.
    with torch.no_grad():
        for name, param in model.named_parameters():
            if 'scores' in name and is_target_layer(name):
                count = param.numel()
                delta = score_changes[offset:offset + count].reshape(param.shape)
                # Match the parameter's device/dtype so changes computed elsewhere can be applied.
                param.add_(learning_rate * delta.to(device=param.device, dtype=param.dtype))
                offset += count

def get_target_params(model):
    """Count the scalar entries across all RL-controlled score parameters of ``model``.

    This equals the length of the vector ``get_all_scores`` returns. It is 0 when nothing matches.
    """
    total = 0
    for param_name, tensor in model.named_parameters():
        if 'scores' not in param_name:
            continue
        if is_target_layer(param_name):
            total += tensor.numel()
    return total

def freeze_non_target_scores(model):
    """Set ``requires_grad = False`` on every ``'scores'`` parameter outside the target layers.

    Target-layer scores and all non-score parameters are left untouched.
    """
    to_freeze = (
        tensor
        for param_name, tensor in model.named_parameters()
        if 'scores' in param_name and not is_target_layer(param_name)
    )
    for tensor in to_freeze:
        tensor.requires_grad = False

# Freeze the scores of every non-target layer.
freeze_non_target_scores(resnet_slth)
# The agent observes and edits one flat vector of the target scores, so its state and
# action widths are both that vector's length.
total_params = get_target_params(resnet_slth)
rl_agent = PolicyGradientAgent(
    state_dim=total_params,
    action_dim=total_params,
    device=device,
)
def compute_accuracy(outputs, labels):
    """Return the fraction (0..1) of rows whose argmax over dim 1 equals ``labels``.

    ``outputs`` is indexed along dim 1 for the class prediction; ``labels`` holds one class
    index per row. An empty batch raises ZeroDivisionError.
    """
    predicted = outputs.detach().max(dim=1).indices
    num_correct = (predicted == labels).sum().item()
    return num_correct / labels.size(0)



# The RL agent is trained on chunks of this many training steps.
rl_update_interval = 100

for epoch in range(num_epochs):
    resnet_slth.train()
    epoch_losses = []
    states = []
    actions = []
    rewards = []

    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # Current state: flattened copy of the RL-controlled scores.
        state = get_all_scores(resnet_slth)

        # Query the policy once. select_action has no randomness, so the same tensor is
        # both applied to the model and stored as the action for the RL update.
        action = rl_agent.select_action(state)

        # Apply the action as an additive change to the target scores.
        update_scores(resnet_slth, action.to(device))

        # Forward pass
        outputs = resnet_slth(images)
        loss = criterion(outputs, labels)

        # Backward and optimize (SGD steps every parameter that received a gradient)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_losses.append(loss.item())
        # Reward: this batch's accuracy for outputs computed after the score change.
        accuracy = compute_accuracy(outputs, labels)

        # Store state, action, and reward for the RL update
        states.append(state)
        actions.append(action)
        rewards.append(accuracy)

        if (i + 1) % rl_update_interval == 0:
            rl_agent.update(states, actions, rewards)
            states, actions, rewards = [], [], []

    # Use the steps left over after the last full interval instead of discarding them when
    # the buffers are reset next epoch. update() standardizes rewards with std(), which
    # needs at least two samples.
    if len(rewards) > 1:
        rl_agent.update(states, actions, rewards)

    # Update the learning rate
    scheduler.step()
    epoch_loss = sum(epoch_losses) / len(epoch_losses)
    losses.append(epoch_loss)

    # Evaluate the model on the test set after every epoch
    resnet_slth.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = resnet_slth(images)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        acc = 100 * correct / total
        print(f"Epoch {epoch}: Loss {epoch_loss:.4f}, ACC {acc:.2f}%")
        val_accuracies.append(acc)